Skip to content

Commit 2ea74b3

Browse files
authored
Allow Multi File Select on Plot Wizard (#4748)
1 parent c6e227c commit 2ea74b3

File tree

11 files changed

+504
-149
lines changed

11 files changed

+504
-149
lines changed

extension/src/fileSystem/index.test.ts

Lines changed: 126 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import {
2323
getPidFromFile,
2424
getEntryFromJsonFile,
2525
addPlotToDvcYamlFile,
26-
loadDataFile
26+
loadDataFiles
2727
} from '.'
2828
import { dvcDemoPath } from '../test/util'
2929
import { DOT_DVC } from '../cli/dvc/constants'
@@ -63,17 +63,22 @@ beforeEach(() => {
6363
jest.resetAllMocks()
6464
})
6565

66-
describe('loadDataFile', () => {
66+
describe('loadDataFiles', () => {
6767
it('should load in csv file contents', async () => {
6868
const mockCsvContent = ['epoch,acc', '10,0.69', '11,0.345'].join('\n')
6969

7070
mockedReadFileSync.mockReturnValueOnce(mockCsvContent)
7171

72-
const result = await loadDataFile('values.csv')
72+
const result = await loadDataFiles(['values.csv'])
7373

7474
expect(result).toStrictEqual([
75-
{ acc: 0.69, epoch: 10 },
76-
{ acc: 0.345, epoch: 11 }
75+
{
76+
data: [
77+
{ acc: 0.69, epoch: 10 },
78+
{ acc: 0.345, epoch: 11 }
79+
],
80+
file: 'values.csv'
81+
}
7782
])
7883
})
7984

@@ -85,11 +90,16 @@ describe('loadDataFile', () => {
8590

8691
mockedReadFileSync.mockReturnValueOnce(mockJsonContent)
8792

88-
const result = await loadDataFile('values.json')
93+
const result = await loadDataFiles(['values.json'])
8994

9095
expect(result).toStrictEqual([
91-
{ acc: 0.69, epoch: 10 },
92-
{ acc: 0.345, epoch: 11 }
96+
{
97+
data: [
98+
{ acc: 0.69, epoch: 10 },
99+
{ acc: 0.345, epoch: 11 }
100+
],
101+
file: 'values.json'
102+
}
93103
])
94104
})
95105

@@ -98,11 +108,16 @@ describe('loadDataFile', () => {
98108

99109
mockedReadFileSync.mockReturnValueOnce(mockTsvContent)
100110

101-
const result = await loadDataFile('values.tsv')
111+
const result = await loadDataFiles(['values.tsv'])
102112

103113
expect(result).toStrictEqual([
104-
{ acc: 0.69, epoch: 10 },
105-
{ acc: 0.345, epoch: 11 }
114+
{
115+
data: [
116+
{ acc: 0.69, epoch: 10 },
117+
{ acc: 0.345, epoch: 11 }
118+
],
119+
file: 'values.tsv'
120+
}
106121
])
107122
})
108123

@@ -115,15 +130,47 @@ describe('loadDataFile', () => {
115130

116131
mockedReadFileSync.mockReturnValueOnce(mockYamlContent)
117132

118-
const result = await loadDataFile('dvc.yaml')
133+
const result = await loadDataFiles(['dvc.yaml'])
119134

120-
expect(result).toStrictEqual({
121-
stages: {
122-
train: {
123-
cmd: 'python train.py'
124-
}
135+
expect(result).toStrictEqual([
136+
{
137+
data: {
138+
stages: {
139+
train: {
140+
cmd: 'python train.py'
141+
}
142+
}
143+
},
144+
file: 'dvc.yaml'
125145
}
126-
})
146+
])
147+
})
148+
149+
it('should load in the contents of multiple files', async () => {
150+
const mockTsvContent = ['epoch\tacc', '10\t0.69', '11\t0.345'].join('\n')
151+
const mockCsvContent = ['epoch2,acc2', '10,0.679', '11,0.3'].join('\n')
152+
153+
mockedReadFileSync.mockReturnValueOnce(mockTsvContent)
154+
mockedReadFileSync.mockReturnValueOnce(mockCsvContent)
155+
156+
const result = await loadDataFiles(['values.tsv', 'values2.csv'])
157+
158+
expect(result).toStrictEqual([
159+
{
160+
data: [
161+
{ acc: 0.69, epoch: 10 },
162+
{ acc: 0.345, epoch: 11 }
163+
],
164+
file: 'values.tsv'
165+
},
166+
{
167+
data: [
168+
{ acc2: 0.679, epoch2: 10 },
169+
{ acc2: 0.3, epoch2: 11 }
170+
],
171+
file: 'values2.csv'
172+
}
173+
])
127174
})
128175

129176
it('should catch any errors thrown during file parsing', async () => {
@@ -133,11 +180,29 @@ describe('loadDataFile', () => {
133180
})
134181

135182
for (const file of dataFiles) {
136-
const resultWithErr = await loadDataFile(file)
183+
const resultWithErr = await loadDataFiles([file])
137184

138185
expect(resultWithErr).toStrictEqual(undefined)
139186
}
140187
})
188+
189+
it('should catch any errors thrown during the parsing of multiple files', async () => {
190+
const dataFiles = ['values.csv', 'file.tsv', 'file.json']
191+
const mockCsvContent = ['epoch,acc', '10,0.69', '11,0.345'].join('\n')
192+
const mockJsonContent = JSON.stringify([
193+
{ acc: 0.69, epoch: 10 },
194+
{ acc: 0.345, epoch: 11 }
195+
])
196+
mockedReadFileSync
197+
.mockReturnValueOnce(mockCsvContent)
198+
.mockImplementationOnce(() => {
199+
throw new Error('fake error')
200+
})
201+
.mockReturnValueOnce(mockJsonContent)
202+
203+
const resultWithErr = await loadDataFiles(dataFiles)
204+
expect(resultWithErr).toStrictEqual(undefined)
205+
})
141206
})
142207

143208
describe('writeJson', () => {
@@ -527,10 +592,11 @@ describe('addPlotToDvcYamlFile', () => {
527592
' eval/prc/test.json: precision'
528593
]
529594
const mockNewPlotLines = [
530-
' - data.json:',
595+
' - simple_plot:',
531596
' template: simple',
532597
' x: epochs',
533-
' y: accuracy'
598+
' y:',
599+
' data.json: accuracy'
534600
]
535601
it('should add a plots list with the new plot if the dvc.yaml file has no plots', () => {
536602
const mockDvcYamlContent = mockStagesLines.join('\n')
@@ -541,10 +607,37 @@ describe('addPlotToDvcYamlFile', () => {
541607
mockedReadFileSync.mockReturnValueOnce(mockDvcYamlContent)
542608

543609
addPlotToDvcYamlFile('/', {
544-
dataFile: '/data.json',
545610
template: 'simple',
546-
x: 'epochs',
547-
y: 'accuracy'
611+
x: { file: '/data.json', key: 'epochs' },
612+
y: { file: '/data.json', key: 'accuracy' }
613+
})
614+
615+
expect(mockedWriteFileSync).toHaveBeenCalledWith(
616+
'//dvc.yaml',
617+
mockDvcYamlContent + mockPlotYamlContent
618+
)
619+
})
620+
621+
it('should add the new plot with fields coming from different files', () => {
622+
const mockDvcYamlContent = mockStagesLines.join('\n')
623+
const mockPlotYamlContent = [
624+
'',
625+
'plots:',
626+
' - simple_plot:',
627+
' template: simple',
628+
' x:',
629+
' data.json: epochs',
630+
' y:',
631+
' acc.json: accuracy',
632+
''
633+
].join('\n')
634+
mockedReadFileSync.mockReturnValueOnce(mockDvcYamlContent)
635+
mockedReadFileSync.mockReturnValueOnce(mockDvcYamlContent)
636+
637+
addPlotToDvcYamlFile('/', {
638+
template: 'simple',
639+
x: { file: '/data.json', key: 'epochs' },
640+
y: { file: '/acc.json', key: 'accuracy' }
548641
})
549642

550643
expect(mockedWriteFileSync).toHaveBeenCalledWith(
@@ -560,10 +653,9 @@ describe('addPlotToDvcYamlFile', () => {
560653
mockedReadFileSync.mockReturnValueOnce(mockDvcYamlContent.join('\n'))
561654

562655
addPlotToDvcYamlFile('/', {
563-
dataFile: '/data.json',
564656
template: 'simple',
565-
x: 'epochs',
566-
y: 'accuracy'
657+
x: { file: '/data.json', key: 'epochs' },
658+
y: { file: '/data.json', key: 'accuracy' }
567659
})
568660

569661
mockDvcYamlContent.splice(7, 0, ...mockPlotYamlContent)
@@ -583,10 +675,9 @@ describe('addPlotToDvcYamlFile', () => {
583675
mockedReadFileSync.mockReturnValueOnce(mockDvcYamlContent)
584676

585677
addPlotToDvcYamlFile('/', {
586-
dataFile: '/data.json',
587678
template: 'simple',
588-
x: 'epochs',
589-
y: 'accuracy'
679+
x: { file: '/data.json', key: 'epochs' },
680+
y: { file: '/data.json', key: 'accuracy' }
590681
})
591682

592683
expect(mockedWriteFileSync).toHaveBeenCalledWith(
@@ -610,20 +701,20 @@ describe('addPlotToDvcYamlFile', () => {
610701
].join('\n')
611702
const mockPlotYamlContent = [
612703
'',
613-
' - data.json:',
704+
' - simple_plot:',
614705
' template: simple',
615706
' x: epochs',
616-
' y: accuracy',
707+
' y:',
708+
' data.json: accuracy',
617709
''
618710
].join('\n')
619711
mockedReadFileSync.mockReturnValueOnce(mockDvcYamlContent)
620712
mockedReadFileSync.mockReturnValueOnce(mockDvcYamlContent)
621713

622714
addPlotToDvcYamlFile('/', {
623-
dataFile: '/data.json',
624715
template: 'simple',
625-
x: 'epochs',
626-
y: 'accuracy'
716+
x: { file: '/data.json', key: 'epochs' },
717+
y: { file: '/data.json', key: 'accuracy' }
627718
})
628719

629720
expect(mockedWriteFileSync).toHaveBeenCalledWith(

extension/src/fileSystem/index.ts

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -214,21 +214,31 @@ const loadYamlAsDoc = (
214214
}
215215
}
216216

217+
/**
 * Builds the single-entry object describing a new plot, ready to be placed
 * in the `plots` list that is stringified into dvc.yaml.
 *
 * The entry is keyed by `<template>_plot`. Data file paths are written
 * relative to `cwd`.
 *
 * @param cwd - Directory that the data file paths are made relative to.
 * @param plot - Chosen template plus the file/key pair for each axis.
 * @returns An object of the shape `{ [name]: { template, x, y } }`.
 */
const getPlotYamlObj = (cwd: string, plot: PlotConfigData) => {
  // y is always written as a `{ <relative file>: <key> }` mapping.
  const yField = { [relative(cwd, plot.y.file)]: plot.y.key }

  // x only needs its own file mapping when it comes from a different file
  // than y; otherwise the bare key is written.
  const xField =
    plot.x.file === plot.y.file
      ? plot.x.key
      : { [relative(cwd, plot.x.file)]: plot.x.key }

  return {
    [`${plot.template}_plot`]: {
      template: plot.template,
      x: xField,
      y: yField
    }
  }
}
228+
217229
const getPlotsYaml = (
218230
cwd: string,
219231
plotObj: PlotConfigData,
220232
indentSearchLines: string[]
221233
) => {
222-
const { dataFile, ...plot } = plotObj
223-
const plotName = relative(cwd, dataFile)
224234
const indentReg = /^( +)[^ ]/
225235
const indentLine = indentSearchLines.find(line => indentReg.test(line)) || ''
226236
const spacesMatches = indentLine.match(indentReg)
227237
const spaces = spacesMatches?.[1]
228238

229239
return yaml
230240
.stringify(
231-
{ plots: [{ [plotName]: plot }] },
241+
{ plots: [getPlotYamlObj(cwd, plotObj)] },
232242
{ indent: spaces ? spaces.length : 2 }
233243
)
234244
.split('\n')
@@ -315,7 +325,7 @@ const loadTsv = (path: string) => {
315325
}
316326
}
317327

318-
export const loadDataFile = (file: string): unknown => {
328+
const loadDataFile = (file: string): unknown => {
319329
const ext = getFileExtension(file)
320330

321331
switch (ext) {
@@ -330,6 +340,22 @@ export const loadDataFile = (file: string): unknown => {
330340
}
331341
}
332342

343+
/**
 * Loads several data files, one after another, in the order given.
 *
 * Each file is read through `loadDataFile`. Loading is all-or-nothing: as
 * soon as one file yields a falsy result the whole call resolves to
 * `undefined` and the remaining files are not read.
 *
 * @param files - Paths of the data files to load.
 * @returns One `{ data, file }` entry per input path, in input order, or
 *   `undefined` if any file produced no data.
 */
export const loadDataFiles = async (
  files: string[]
): Promise<{ file: string; data: unknown }[] | undefined> => {
  const loaded: { file: string; data: unknown }[] = []

  // Deliberately sequential: a failure must stop before later files are read.
  for (const file of files) {
    const data = await loadDataFile(file)

    if (data) {
      loaded.push({ data, file })
      continue
    }

    return undefined
  }

  return loaded
}
358+
333359
export const writeJson = <
334360
T extends Record<string, unknown> | Array<Record<string, unknown>>
335361
>(

extension/src/pipeline/index.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ export class Pipeline extends DeferredDisposable {
119119
return
120120
}
121121

122-
const plotConfiguration = await pickPlotConfiguration()
122+
const plotConfiguration = await pickPlotConfiguration(cwd)
123123

124124
if (!plotConfiguration) {
125125
return

0 commit comments

Comments
 (0)