Skip to content

Commit 63d30fe

Browse files
authored
Ensure experiment summary info (columns) is always available in the experiment table data (#4396)
1 parent 92804d9 commit 63d30fe

File tree

8 files changed

+231
-159
lines changed

8 files changed

+231
-159
lines changed
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import { Column, ColumnType } from '../../webview/contract'
2+
import { EXPERIMENT_COLUMN_ID } from '../constants'
3+
4+
export const collectColumnOrder = async (
5+
existingColumnOrder: string[],
6+
terminalNodes: Column[]
7+
): Promise<string[]> => {
8+
const acc: { [columnType: string]: string[] } = {
9+
[ColumnType.DEPS]: [],
10+
[ColumnType.METRICS]: [],
11+
[ColumnType.PARAMS]: [],
12+
[ColumnType.TIMESTAMP]: []
13+
}
14+
for (const { type, path } of terminalNodes) {
15+
if (existingColumnOrder.includes(path)) {
16+
continue
17+
}
18+
acc[type].push(path)
19+
}
20+
21+
// eslint-disable-next-line etc/no-assign-mutated-array
22+
await Promise.all([acc.metrics.sort(), acc.params.sort(), acc.deps.sort()])
23+
24+
if (!existingColumnOrder.includes(EXPERIMENT_COLUMN_ID)) {
25+
existingColumnOrder.unshift(EXPERIMENT_COLUMN_ID)
26+
}
27+
28+
return [
29+
...existingColumnOrder,
30+
...acc.timestamp,
31+
...acc.metrics,
32+
...acc.params,
33+
...acc.deps
34+
]
35+
}

extension/src/experiments/columns/model.test.ts

Lines changed: 69 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1+
import { join } from 'path'
12
import { Disposable, Disposer } from '@hediet/std/disposable'
23
import { ColumnsModel } from './model'
34
import { appendColumnToPath, buildMetricOrParamPath } from './paths'
4-
import { timestampColumn } from './constants'
5+
import { EXPERIMENT_COLUMN_ID, timestampColumn } from './constants'
56
import { buildMockMemento } from '../../test/util'
67
import { generateTestExpShowOutput } from '../../test/util/experiments'
78
import { Status } from '../../path/selection/model'
@@ -258,7 +259,7 @@ describe('ColumnsModel', () => {
258259
expect(model.getColumnOrder()).toStrictEqual(persistedState)
259260
})
260261

261-
it('should return the first three visible columns for both metrics and params from the persisted state', async () => {
262+
it('should return the first three visible columns for both metrics and params from the persisted state (first)', async () => {
262263
const persistedState = [
263264
'id',
264265
'Created',
@@ -287,7 +288,8 @@ describe('ColumnsModel', () => {
287288
'params:params.yaml:process.threshold',
288289
'params:params.yaml:process.test_arg',
289290
'metrics:summary.json:loss',
290-
'metrics:summary.json:accuracy'
291+
'metrics:summary.json:accuracy',
292+
'metrics:summary.json:val_accuracy'
291293
])
292294

293295
model.toggleStatus('params:params.yaml:dvc_logs_dir')
@@ -297,10 +299,65 @@ describe('ColumnsModel', () => {
297299
'params:params.yaml:process.test_arg',
298300
'params:params.yaml:dropout',
299301
'metrics:summary.json:loss',
300-
'metrics:summary.json:accuracy'
302+
'metrics:summary.json:accuracy',
303+
'metrics:summary.json:val_accuracy'
304+
])
305+
})
306+
307+
it('should not add a param that is no longer present to the summary column order', async () => {
308+
const persistedState = [
309+
'id',
310+
'Created',
311+
'params:params.yaml:an-old-params'
312+
]
313+
314+
const model = new ColumnsModel(
315+
exampleDvcRoot,
316+
buildMockMemento({
317+
[PersistenceKey.METRICS_AND_PARAMS_COLUMN_ORDER + exampleDvcRoot]:
318+
persistedState
319+
}),
320+
mockedColumnsOrderOrStatusChanged
321+
)
322+
model.toggleStatus('params:params.yaml:an-old-params')
323+
await model.transformAndSet(outputFixture)
324+
325+
expect(model.getSummaryColumnOrder()).toStrictEqual([
326+
join('params:nested', 'params.yaml:test'),
327+
'params:params.yaml:code_names',
328+
'params:params.yaml:dropout',
329+
'metrics:summary.json:accuracy',
330+
'metrics:summary.json:loss',
331+
'metrics:summary.json:val_accuracy'
301332
])
302333
})
303334

335+
it('should add to the persisted state when there are columns that were not found', async () => {
336+
const persistedState = ['params:params.yaml:dvc_logs_dir']
337+
338+
const model = new ColumnsModel(
339+
exampleDvcRoot,
340+
buildMockMemento({
341+
[PersistenceKey.METRICS_AND_PARAMS_COLUMN_ORDER + exampleDvcRoot]:
342+
persistedState
343+
}),
344+
mockedColumnsOrderOrStatusChanged
345+
)
346+
await model.transformAndSet(outputFixture)
347+
348+
expect(model.getSummaryColumnOrder()).toStrictEqual([
349+
'params:params.yaml:dvc_logs_dir',
350+
join('params:nested', 'params.yaml:test'),
351+
'params:params.yaml:code_names',
352+
'metrics:summary.json:accuracy',
353+
'metrics:summary.json:loss',
354+
'metrics:summary.json:val_accuracy'
355+
])
356+
357+
const [id] = model.getColumnOrder()
358+
expect(id).toStrictEqual(EXPERIMENT_COLUMN_ID)
359+
})
360+
304361
it('should return the first three metric and param columns (none hidden) collected from data if state is empty', async () => {
305362
const model = new ColumnsModel(
306363
exampleDvcRoot,
@@ -310,23 +367,23 @@ describe('ColumnsModel', () => {
310367
await model.transformAndSet(outputFixture)
311368

312369
expect(model.getSummaryColumnOrder()).toStrictEqual([
370+
join('params:nested', 'params.yaml:test'),
313371
'params:params.yaml:code_names',
314-
'params:params.yaml:epochs',
315-
'params:params.yaml:learning_rate',
316-
'metrics:summary.json:loss',
372+
'params:params.yaml:dropout',
317373
'metrics:summary.json:accuracy',
318-
'metrics:summary.json:val_loss'
374+
'metrics:summary.json:loss',
375+
'metrics:summary.json:val_accuracy'
319376
])
320377

321378
model.toggleStatus('params:params.yaml:code_names')
322379

323380
expect(model.getSummaryColumnOrder()).toStrictEqual([
324-
'params:params.yaml:epochs',
325-
'params:params.yaml:learning_rate',
381+
join('params:nested', 'params.yaml:test'),
382+
'params:params.yaml:dropout',
326383
'params:params.yaml:dvc_logs_dir',
327-
'metrics:summary.json:loss',
328384
'metrics:summary.json:accuracy',
329-
'metrics:summary.json:val_loss'
385+
'metrics:summary.json:loss',
386+
'metrics:summary.json:val_accuracy'
330387
])
331388
})
332389

extension/src/experiments/columns/model.ts

Lines changed: 16 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@ import {
55
collectRelativeMetricsFiles,
66
collectParamsFiles
77
} from './collect'
8-
import { EXPERIMENT_COLUMN_ID, timestampColumn } from './constants'
98
import {
109
MAX_SUMMARY_ORDER_LENGTH,
1110
SummaryAcc,
1211
collectFromColumnOrder as collectSummaryColumnOrder,
1312
limitSummaryOrder
1413
} from './util'
14+
import { collectColumnOrder } from './collect/order'
1515
import { Column, ColumnType } from '../webview/contract'
1616
import { ExpShowOutput } from '../../cli/dvc/contract'
1717
import { PersistenceKey } from '../../persistence/constants'
@@ -158,38 +158,13 @@ export class ColumnsModel extends PathSelectionModel<Column> {
158158
)
159159
}
160160

161-
private findChildrenColumns(
162-
parent: string,
163-
columns: Column[],
164-
childrenColumns: string[]
165-
) {
166-
const filteredColumns = columns.filter(
167-
({ parentPath }) => parentPath === parent
161+
private async setColumnOrderFromData(terminalNodes: Column[]) {
162+
const extendedColumnOrder = await collectColumnOrder(
163+
this.columnOrderState,
164+
terminalNodes
168165
)
169-
for (const column of filteredColumns) {
170-
if (column.hasChildren) {
171-
this.findChildrenColumns(column.path, columns, childrenColumns)
172-
} else {
173-
childrenColumns.push(column.path)
174-
}
175-
}
176-
}
177166

178-
private getColumnsFromType(type: string): string[] {
179-
const childrenColumns: string[] = []
180-
const dataWithType = this.data.filter(({ path }) => path.startsWith(type))
181-
this.findChildrenColumns(type, dataWithType, childrenColumns)
182-
return childrenColumns
183-
}
184-
185-
private getColumnOrderFromData() {
186-
return [
187-
EXPERIMENT_COLUMN_ID,
188-
timestampColumn.path,
189-
...this.getColumnsFromType(ColumnType.METRICS),
190-
...this.getColumnsFromType(ColumnType.PARAMS),
191-
...this.getColumnsFromType(ColumnType.DEPS)
192-
]
167+
this.setColumnOrder(extendedColumnOrder)
193168
}
194169

195170
private async transformAndSetColumns(data: ExpShowOutput) {
@@ -203,12 +178,18 @@ export class ColumnsModel extends PathSelectionModel<Column> {
203178

204179
this.data = columns
205180

206-
if (this.columnOrderState.length === 0) {
207-
this.setColumnOrder(this.getColumnOrderFromData())
208-
}
209-
210181
this.paramsFiles = paramsFiles
211182
this.relativeMetricsFiles = relativeMetricsFiles
183+
184+
const selectedColumns = this.getTerminalNodes().filter(
185+
({ selected }) => selected
186+
)
187+
188+
for (const { path } of selectedColumns) {
189+
if (!this.columnOrderState.includes(path)) {
190+
return this.setColumnOrderFromData(selectedColumns)
191+
}
192+
}
212193
}
213194

214195
private transformAndSetChanges(data: ExpShowOutput) {

extension/src/path/selection/model.ts

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,22 @@ export abstract class PathSelectionModel<
6767
}
6868

6969
protected setNewStatuses(data: { path: string }[]) {
70+
const paths = new Set<string>()
7071
for (const { path } of data) {
7172
if (this.status[path] === undefined) {
7273
this.status[path] = Status.SELECTED
7374
}
75+
paths.add(path)
76+
}
77+
78+
this.removeMissingSelected(paths)
79+
}
80+
81+
private removeMissingSelected(paths: Set<string>) {
82+
for (const [path, status] of Object.entries(this.status)) {
83+
if (!paths.has(path) && status === Status.SELECTED) {
84+
delete this.status[path]
85+
}
7486
}
7587
}
7688

extension/src/test/fixtures/expShow/base/columns.ts

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,27 +11,27 @@ const nestedParamsFile = join('nested', 'params.yaml')
1111
export const dataColumnOrder: string[] = [
1212
'id',
1313
'Created',
14-
'metrics:summary.json:loss',
1514
'metrics:summary.json:accuracy',
16-
'metrics:summary.json:val_loss',
15+
'metrics:summary.json:loss',
1716
'metrics:summary.json:val_accuracy',
17+
'metrics:summary.json:val_loss',
18+
join('params:nested', 'params.yaml:test'),
1819
'params:params.yaml:code_names',
20+
'params:params.yaml:dropout',
21+
'params:params.yaml:dvc_logs_dir',
1922
'params:params.yaml:epochs',
2023
'params:params.yaml:learning_rate',
21-
'params:params.yaml:dvc_logs_dir',
2224
'params:params.yaml:log_file',
23-
'params:params.yaml:dropout',
24-
'params:params.yaml:process.threshold',
2525
'params:params.yaml:process.test_arg',
26-
join('params:nested', 'params.yaml:test'),
26+
'params:params.yaml:process.threshold',
2727
join('deps:data', 'data.xml'),
28-
join('deps:data', 'prepared'),
2928
join('deps:data', 'features'),
30-
join('deps:src', 'prepare.py'),
31-
join('deps:src', 'featurization.py'),
32-
join('deps:src', 'train.py'),
29+
join('deps:data', 'prepared'),
30+
'deps:model.pkl',
3331
join('deps:src', 'evaluate.py'),
34-
'deps:model.pkl'
32+
join('deps:src', 'featurization.py'),
33+
join('deps:src', 'prepare.py'),
34+
join('deps:src', 'train.py')
3535
]
3636

3737
const data: Column[] = [

0 commit comments

Comments
 (0)