Skip to content

Commit d245ba6

Browse files
authored
Accomodate params that are lists (#1818)
* add arrays to exp show data type * return string version of the array to the experiments table to match CLI * ensure we pass the right data when using modify and commands * unit test and improve quick pick * add string and boolean data types
1 parent 9981bd9 commit d245ba6

File tree

15 files changed

+113
-14
lines changed

15 files changed

+113
-14
lines changed

extension/src/cli/reader.ts

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,14 @@ export type StatusesOrAlwaysChanged = StageOrFileStatuses | 'always changed'
4949

5050
export type StatusOutput = Record<string, StatusesOrAlwaysChanged[]>
5151

52-
export type Value = string | number | boolean | null
52+
export type Value =
53+
| string
54+
| number
55+
| boolean
56+
| null
57+
| number[]
58+
| string[]
59+
| boolean[]
5360

5461
export interface ValueTreeOrError {
5562
data?: ValueTree

extension/src/experiments/columns/collect.test.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,7 @@ describe('collectColumns', () => {
416416
joinColumnPath(ColumnType.METRICS, 'summary.json', 'val_loss'),
417417
joinColumnPath(ColumnType.METRICS, 'summary.json', 'val_accuracy'),
418418
joinColumnPath(ColumnType.PARAMS, 'params.yaml'),
419+
joinColumnPath(ColumnType.PARAMS, 'params.yaml', 'code_names'),
419420
joinColumnPath(ColumnType.PARAMS, 'params.yaml', 'epochs'),
420421
joinColumnPath(ColumnType.PARAMS, 'params.yaml', 'learning_rate'),
421422
joinColumnPath(ColumnType.PARAMS, 'params.yaml', 'dvc_logs_dir'),

extension/src/experiments/columns/collect.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ const getValueType = (value: Value) => {
1717
if (value === null) {
1818
return 'null'
1919
}
20+
if (Array.isArray(value)) {
21+
return 'array'
22+
}
2023
return typeof value
2124
}
2225

extension/src/experiments/columns/tree.test.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,14 @@ describe('ExperimentsColumnsTree', () => {
217217
path: paramsPath
218218
})
219219
expect(grandChildren).toStrictEqual([
220+
{
221+
collapsibleState: 0,
222+
description: undefined,
223+
dvcRoot: mockedDvcRoot,
224+
iconPath: mockedSelectedCheckbox,
225+
label: 'code_names',
226+
path: appendColumnToPath(paramsPath, 'code_names')
227+
},
220228
{
221229
collapsibleState: 0,
222230
description: undefined,

extension/src/experiments/columns/walk.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ const walkValueTree = (
2727
ancestors: string[] = []
2828
) => {
2929
for (const [key, value] of Object.entries(tree)) {
30-
if (value && typeof value === 'object') {
30+
if (value && !Array.isArray(value) && typeof value === 'object') {
3131
walkValueTree(value, meta, onValue, [...ancestors, key])
3232
} else {
3333
onValue(key, value, meta, ancestors)

extension/src/experiments/model/queue/collect.test.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ describe('collectFlatExperimentParams', () => {
77
it('should flatten the params into an array', () => {
88
const params = collectFlatExperimentParams(rowsFixture[0].params)
99
expect(params).toStrictEqual([
10+
{ path: appendColumnToPath('params.yaml', 'code_names'), value: [0, 1] },
1011
{ path: appendColumnToPath('params.yaml', 'epochs'), value: 2 },
1112
{
1213
path: appendColumnToPath('params.yaml', 'learning_rate'),

extension/src/experiments/model/queue/collect.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,18 @@ import { Columns } from '../../webview/contract'
44

55
export type Param = {
66
path: string
7-
value: number | string | boolean
7+
value: Value
88
}
99

1010
const collectFromParamsFile = (
11-
acc: { path: string; value: string | number | boolean }[],
11+
acc: { path: string; value: Value }[],
1212
key: string | undefined,
1313
value: Value | ValueTree,
1414
ancestors: string[] = []
1515
) => {
1616
const pathArray = [...ancestors, key].filter(Boolean) as string[]
1717

18-
if (typeof value === 'object') {
18+
if (!Array.isArray(value) && typeof value === 'object') {
1919
for (const [childKey, childValue] of Object.entries(value as ValueTree)) {
2020
collectFromParamsFile(acc, childKey, childValue, pathArray)
2121
}

extension/src/experiments/model/queue/quickPick.test.ts

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,27 +48,38 @@ describe('pickAndModifyParams', () => {
4848
const unchanged = { path: 'params.yaml:learning_rate', value: 2e-12 }
4949
const initialUserResponse = [
5050
{ path: 'params.yaml:dropout', value: 0.15 },
51-
{ path: 'params.yaml:process.threshold', value: 0.86 }
51+
{ path: 'params.yaml:process.threshold', value: 0.86 },
52+
{ path: 'params.yaml:code_names', value: [0, 1, 2] }
5253
]
5354
mockedQuickPickManyValues.mockResolvedValueOnce(initialUserResponse)
5455
const firstInput = '0.16'
5556
const secondInput = '0.87'
57+
const thirdInput = '[0,1,3]'
5658
mockedGetInput.mockResolvedValueOnce(firstInput)
5759
mockedGetInput.mockResolvedValueOnce(secondInput)
60+
mockedGetInput.mockResolvedValueOnce(thirdInput)
5861

5962
const paramsToQueue = await pickAndModifyParams([
6063
unchanged,
6164
...initialUserResponse
6265
])
6366

67+
expect(mockedGetInput).toBeCalledTimes(3)
68+
expect(mockedGetInput).toBeCalledWith(
69+
'Enter a Value for params.yaml:code_names',
70+
'[0,1,2]'
71+
)
72+
6473
expect(paramsToQueue).toStrictEqual([
6574
'-S',
6675
`params.yaml:dropout=${firstInput}`,
6776
'-S',
6877
`params.yaml:process.threshold=${secondInput}`,
6978
'-S',
79+
`params.yaml:code_names=${thirdInput}`,
80+
'-S',
7081
[unchanged.path, unchanged.value].join('=')
7182
])
72-
expect(mockedGetInput).toBeCalledTimes(2)
83+
expect(mockedGetInput).toBeCalledTimes(3)
7384
})
7485
})

extension/src/experiments/model/queue/quickPick.ts

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,15 @@ import { getInput } from '../../../vscode/inputBox'
44
import { Flag } from '../../../cli/constants'
55
import { definedAndNonEmpty } from '../../../util/array'
66
import { getEnterValueTitle, Title } from '../../../vscode/title'
7+
import { Value } from '../../../cli/reader'
8+
9+
const standardizeValue = (value: Value): string =>
10+
typeof value === 'object' ? JSON.stringify(value) : `${value}`
711

812
const pickParamsToModify = (params: Param[]): Thenable<Param[] | undefined> =>
913
quickPickManyValues<Param>(
1014
params.map(param => ({
11-
description: `${param.value}`,
15+
description: standardizeValue(param.value),
1216
label: param.path,
1317
picked: false,
1418
value: param
@@ -21,18 +25,21 @@ const pickNewParamValues = async (
2125
): Promise<string[] | undefined> => {
2226
const args: string[] = []
2327
for (const { path, value } of paramsToModify) {
24-
const input = await getInput(getEnterValueTitle(path), `${value}`)
28+
const input = await getInput(
29+
getEnterValueTitle(path),
30+
standardizeValue(value)
31+
)
2532
if (input === undefined) {
2633
return
2734
}
28-
args.push(Flag.SET_PARAM, [path, input.trim()].join('='))
35+
args.push(Flag.SET_PARAM, [path, standardizeValue(input.trim())].join('='))
2936
}
3037
return args
3138
}
3239

3340
const addUnchanged = (args: string[], unchanged: Param[]) => {
3441
for (const { path, value } of unchanged) {
35-
args.push(Flag.SET_PARAM, [path, value].join('='))
42+
args.push(Flag.SET_PARAM, [path, standardizeValue(value)].join('='))
3643
}
3744

3845
return args

extension/src/test/fixtures/expShow/columns.ts

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,16 @@ const data: Column[] = [
6767
parentPath: ColumnType.PARAMS,
6868
path: joinColumnPath(ColumnType.PARAMS, 'params.yaml')
6969
},
70+
{
71+
type: ColumnType.PARAMS,
72+
hasChildren: false,
73+
maxStringLength: 3,
74+
name: 'code_names',
75+
parentPath: joinColumnPath(ColumnType.PARAMS, 'params.yaml'),
76+
path: joinColumnPath(ColumnType.PARAMS, 'params.yaml', 'code_names'),
77+
pathArray: [ColumnType.PARAMS, 'params.yaml', 'code_names'],
78+
types: ['array']
79+
},
7080
{
7181
type: ColumnType.PARAMS,
7282
hasChildren: false,

0 commit comments

Comments
 (0)