Skip to content

Commit e718d3e

Browse files
authored
Add "image by step" plots to comparison section (#4319)
1 parent 3cd5400 commit e718d3e

40 files changed

+685
-179
lines changed

extension/src/cli/dvc/constants.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ export const DEFAULT_CURRENT_BRANCH_COMMITS_TO_SHOW = 3
1919
export const DEFAULT_OTHER_BRANCH_COMMITS_TO_SHOW = 1
2020
export const NUM_OF_COMMITS_TO_INCREASE = 2
2121

22+
export const MULTI_IMAGE_PATH_REG = /[^/]+[/\\]\d+\.[a-z]+$/i
23+
2224
export enum Command {
2325
ADD = 'add',
2426
CHECKOUT = 'checkout',

extension/src/cli/dvc/index.test.ts

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1+
import { join } from 'path'
12
import { EventEmitter } from 'vscode'
23
import { Disposable, Disposer } from '@hediet/std/disposable'
34
import { DvcCli } from '.'
4-
import { Command } from './constants'
5+
import { Command, MULTI_IMAGE_PATH_REG } from './constants'
56
import { CliResult, CliStarted, typeCheckCommands } from '..'
67
import { getProcessEnv } from '../../env'
78
import { createProcess } from '../../process/execution'
@@ -52,6 +53,53 @@ describe('typeCheckCommands', () => {
5253
})
5354
})
5455

56+
describe('Comparison Multi Image Regex', () => {
57+
it('should match a nested image group directory', () => {
58+
expect(
59+
MULTI_IMAGE_PATH_REG.test(
60+
join(
61+
'extremely',
62+
'super',
63+
'super',
64+
'super',
65+
'nested',
66+
'image',
67+
'768.svg'
68+
)
69+
)
70+
).toBe(true)
71+
})
72+
73+
it('should match directories with spaces or special characters', () => {
74+
expect(MULTI_IMAGE_PATH_REG.test(join('mis classified', '5.png'))).toBe(
75+
true
76+
)
77+
78+
expect(MULTI_IMAGE_PATH_REG.test(join('misclassified#^', '5.png'))).toBe(
79+
true
80+
)
81+
})
82+
83+
it('should match different types of images', () => {
84+
const imageFormats = ['svg', 'png', 'jpg', 'jpeg']
85+
for (const format of imageFormats) {
86+
expect(
87+
MULTI_IMAGE_PATH_REG.test(join('misclassified', `5.${format}`))
88+
).toBe(true)
89+
}
90+
})
91+
92+
it('should not match files that include none digits or do not have a file extension', () => {
93+
expect(MULTI_IMAGE_PATH_REG.test(join('misclassified', 'five.png'))).toBe(
94+
false
95+
)
96+
expect(MULTI_IMAGE_PATH_REG.test(join('misclassified', '5 4.png'))).toBe(
97+
false
98+
)
99+
expect(MULTI_IMAGE_PATH_REG.test(join('misclassified', '5'))).toBe(false)
100+
})
101+
})
102+
55103
describe('executeDvcProcess', () => {
56104
it('should pass the correct details to the underlying process given no path to the cli or python binary path', async () => {
57105
const existingPath = joinEnvPath(

extension/src/fileSystem/util.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { sep } from 'path'
1+
import { sep, parse } from 'path'
22

33
export const getPathArray = (path: string): string[] => path.split(sep)
44

@@ -18,3 +18,5 @@ export const getParent = (pathArray: string[], idx: number) => {
1818

1919
export const removeTrailingSlash = (path: string): string =>
2020
path.endsWith(sep) ? path.slice(0, -1) : path
21+
22+
export const getFileNameWithoutExt = (path: string) => parse(path).name

extension/src/plots/model/collect.test.ts

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,17 +84,18 @@ describe('collectData', () => {
8484
expect(Object.keys(comparisonData.main)).toStrictEqual([
8585
join('plots', 'acc.png'),
8686
heatmapPlot,
87-
join('plots', 'loss.png')
87+
join('plots', 'loss.png'),
88+
join('plots', 'image')
8889
])
8990

9091
const testBranchHeatmap = comparisonData['test-branch'][heatmapPlot]
9192

9293
expect(testBranchHeatmap).toBeDefined()
93-
expect(testBranchHeatmap).toStrictEqual(
94+
expect(testBranchHeatmap).toStrictEqual([
9495
plotsDiffFixture.data[heatmapPlot].find(({ revisions }) =>
9596
sameContents(revisions as string[], ['test-branch'])
9697
)
97-
)
98+
])
9899
})
99100
})
100101

extension/src/plots/model/collect.ts

Lines changed: 110 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ import {
1212
TemplatePlotSection,
1313
PlotsType,
1414
CustomPlotData,
15-
CustomPlotValues
15+
CustomPlotValues,
16+
ComparisonRevisionData,
17+
ComparisonPlotImg
1618
} from '../webview/contract'
1719
import { PlotsOutput } from '../../cli/dvc/contract'
1820
import { splitColumnPath } from '../../experiments/columns/paths'
@@ -34,6 +36,12 @@ import {
3436
import { StrokeDashEncoding } from '../multiSource/constants'
3537
import { exists } from '../../fileSystem'
3638
import { hasKey } from '../../util/object'
39+
import { MULTI_IMAGE_PATH_REG } from '../../cli/dvc/constants'
40+
import {
41+
getFileNameWithoutExt,
42+
getParent,
43+
getPathArray
44+
} from '../../fileSystem/util'
3745

3846
export const getCustomPlotId = (metric: string, param: string) =>
3947
`custom-${metric}-${param}`
@@ -126,18 +134,31 @@ export type RevisionData = {
126134
[label: string]: RevisionPathData
127135
}
128136

137+
type ComparisonDataImgPlot = ImagePlot & { ind?: number }
138+
129139
export type ComparisonData = {
130140
[label: string]: {
131-
[path: string]: ImagePlot
141+
[path: string]: ComparisonDataImgPlot[]
132142
}
133143
}
134144

145+
const getMultiImagePath = (path: string) =>
146+
getParent(getPathArray(path), 0) as string
147+
148+
const getMultiImageInd = (path: string) => {
149+
const fileName = getFileNameWithoutExt(path)
150+
return Number(fileName)
151+
}
152+
135153
const collectImageData = (
136154
acc: ComparisonData,
137155
path: string,
138156
plot: ImagePlot
139157
) => {
158+
const isMultiImgPlot = MULTI_IMAGE_PATH_REG.test(path)
159+
const pathLabel = isMultiImgPlot ? getMultiImagePath(path) : path
140160
const id = plot.revisions?.[0]
161+
141162
if (!id) {
142163
return
143164
}
@@ -146,7 +167,17 @@ const collectImageData = (
146167
acc[id] = {}
147168
}
148169

149-
acc[id][path] = plot
170+
if (!acc[id][pathLabel]) {
171+
acc[id][pathLabel] = []
172+
}
173+
174+
const imgPlot: ComparisonDataImgPlot = { ...plot }
175+
176+
if (isMultiImgPlot) {
177+
imgPlot.ind = getMultiImageInd(path)
178+
}
179+
180+
acc[id][pathLabel].push(imgPlot)
150181
}
151182

152183
const collectDatapoints = (
@@ -202,6 +233,16 @@ const collectPathData = (acc: DataAccumulator, path: string, plots: Plot[]) => {
202233
}
203234
}
204235

236+
const sortComparisonImgPaths = (acc: DataAccumulator) => {
237+
for (const [label, paths] of Object.entries(acc.comparisonData)) {
238+
for (const path of Object.keys(paths)) {
239+
acc.comparisonData[label][path].sort(
240+
(img1, img2) => (img1.ind || 0) - (img2.ind || 0)
241+
)
242+
}
243+
}
244+
}
245+
205246
export const collectData = (output: PlotsOutput): DataAccumulator => {
206247
const { data } = output
207248
const acc = {
@@ -213,6 +254,72 @@ export const collectData = (output: PlotsOutput): DataAccumulator => {
213254
collectPathData(acc, path, plots)
214255
}
215256

257+
sortComparisonImgPaths(acc)
258+
259+
return acc
260+
}
261+
262+
type ComparisonPlotsAcc = { path: string; revisions: ComparisonRevisionData }[]
263+
264+
type GetComparisonPlotImg = (
265+
img: ImagePlot,
266+
id: string,
267+
path: string
268+
) => ComparisonPlotImg
269+
270+
const collectSelectedPathComparisonPlots = ({
271+
acc,
272+
comparisonData,
273+
path,
274+
selectedRevisionIds,
275+
getComparisonPlotImg
276+
}: {
277+
acc: ComparisonPlotsAcc
278+
comparisonData: ComparisonData
279+
path: string
280+
selectedRevisionIds: string[]
281+
getComparisonPlotImg: GetComparisonPlotImg
282+
}) => {
283+
const pathRevisions = {
284+
path,
285+
revisions: {} as ComparisonRevisionData
286+
}
287+
288+
for (const id of selectedRevisionIds) {
289+
const imgs = comparisonData[id]?.[path]
290+
pathRevisions.revisions[id] = {
291+
id,
292+
imgs: imgs
293+
? imgs.map(img => getComparisonPlotImg(img, id, path))
294+
: [{ errors: undefined, loading: false, url: undefined }]
295+
}
296+
}
297+
acc.push(pathRevisions)
298+
}
299+
300+
export const collectSelectedComparisonPlots = ({
301+
comparisonData,
302+
paths,
303+
selectedRevisionIds,
304+
getComparisonPlotImg
305+
}: {
306+
comparisonData: ComparisonData
307+
paths: string[]
308+
selectedRevisionIds: string[]
309+
getComparisonPlotImg: GetComparisonPlotImg
310+
}) => {
311+
const acc: ComparisonPlotsAcc = []
312+
313+
for (const path of paths) {
314+
collectSelectedPathComparisonPlots({
315+
acc,
316+
comparisonData,
317+
getComparisonPlotImg,
318+
path,
319+
selectedRevisionIds
320+
})
321+
}
322+
216323
return acc
217324
}
218325

extension/src/plots/model/index.ts

Lines changed: 21 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ import {
1212
collectImageUrl,
1313
collectIdShas,
1414
collectSelectedTemplatePlotRawData,
15-
collectCustomPlotRawData
15+
collectCustomPlotRawData,
16+
collectSelectedComparisonPlots
1617
} from './collect'
1718
import { getRevisionSummaryColumns } from './util'
1819
import {
@@ -21,9 +22,7 @@ import {
2122
CustomPlotsOrderValue
2223
} from './custom'
2324
import {
24-
ComparisonPlots,
2525
Revision,
26-
ComparisonRevisionData,
2726
DEFAULT_SECTION_COLLAPSED,
2827
DEFAULT_SECTION_NB_ITEMS_PER_ROW_OR_WIDTH,
2928
PlotsSection,
@@ -33,7 +32,8 @@ import {
3332
DEFAULT_HEIGHT,
3433
DEFAULT_NB_ITEMS_PER_ROW,
3534
PlotHeight,
36-
SmoothPlotValues
35+
SmoothPlotValues,
36+
ImagePlot
3737
} from '../webview/contract'
3838
import {
3939
EXPERIMENT_WORKSPACE_ID,
@@ -427,37 +427,23 @@ export class PlotsModel extends ModelWithPersistence {
427427
paths: string[],
428428
selectedRevisionIds: string[]
429429
) {
430-
const acc: ComparisonPlots = []
431-
for (const path of paths) {
432-
this.collectSelectedPathComparisonPlots(acc, path, selectedRevisionIds)
433-
}
434-
return acc
435-
}
436-
437-
private collectSelectedPathComparisonPlots(
438-
acc: ComparisonPlots,
439-
path: string,
440-
selectedRevisionIds: string[]
441-
) {
442-
const pathRevisions = {
443-
path,
444-
revisions: {} as ComparisonRevisionData
445-
}
446-
447-
for (const id of selectedRevisionIds) {
448-
const image = this.comparisonData?.[id]?.[path]
449-
const errors = this.errors.getImageErrors(path, id)
450-
const fetched = this.fetchedRevs.has(id)
451-
const url = collectImageUrl(image, fetched)
452-
const loading = !fetched && !url
453-
pathRevisions.revisions[id] = {
454-
errors,
455-
id,
456-
loading,
457-
url
458-
}
459-
}
460-
acc.push(pathRevisions)
430+
return collectSelectedComparisonPlots({
431+
comparisonData: this.comparisonData,
432+
getComparisonPlotImg: (image: ImagePlot, id: string, path: string) => {
433+
const errors = this.errors.getImageErrors(path, id)
434+
const fetched = this.fetchedRevs.has(id)
435+
const url = collectImageUrl(image, fetched)
436+
const loading = !fetched && !url
437+
438+
return {
439+
errors,
440+
loading,
441+
url
442+
}
443+
},
444+
paths,
445+
selectedRevisionIds
446+
})
461447
}
462448

463449
private getSelectedTemplatePlots(

extension/src/plots/paths/collect.test.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,13 @@ describe('collectPaths', () => {
4545
revisions: new Set(REVISIONS),
4646
type: new Set(['comparison'])
4747
},
48+
{
49+
hasChildren: false,
50+
parentPath: 'plots',
51+
path: join('plots', 'image'),
52+
revisions: new Set(REVISIONS),
53+
type: new Set(['comparison'])
54+
},
4855
{
4956
hasChildren: false,
5057
parentPath: 'logs',

0 commit comments

Comments
 (0)