Skip to content

Commit 484a604

Browse files
authored
Add Custom Plots Section (#3342)
1 parent a09c1e2 commit 484a604

File tree

31 files changed

+1274
-23
lines changed

31 files changed

+1274
-23
lines changed

extension/src/experiments/index.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,10 @@ export class Experiments extends BaseRepository<TableData> {
439439
return this.experiments.getFinishedExperiments()
440440
}
441441

442+
public getExperiments() {
443+
return this.experiments.getExperiments()
444+
}
445+
442446
public getExperimentDisplayName(experimentId: string) {
443447
const experiment = this.experiments
444448
.getCombinedList()
@@ -501,6 +505,10 @@ export class Experiments extends BaseRepository<TableData> {
501505
return this.columns.getFirstThreeColumnOrder()
502506
}
503507

508+
public getColumnTerminalNodes() {
509+
return this.columns.getTerminalNodes()
510+
}
511+
504512
public getHasData() {
505513
if (this.deferred.state === 'none') {
506514
return

extension/src/persistence/constants.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ export enum PersistenceKey {
1010
PLOT_COMPARISON_ORDER = 'plotComparisonOrder:',
1111
PLOT_COMPARISON_PATHS_ORDER = 'plotComparisonPathsOrder',
1212
PLOT_METRIC_ORDER = 'plotMetricOrder:',
13+
PLOTS_CUSTOM_ORDER = 'plotCustomOrder:',
1314
PLOT_SECTION_COLLAPSED = 'plotSectionCollapsed:',
1415
PLOT_SELECTED_METRICS = 'plotSelectedMetrics:',
1516
PLOT_SIZES = 'plotSizes:',

extension/src/plots/model/collect.test.ts

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@ import {
66
collectCheckpointPlotsData,
77
collectTemplates,
88
collectMetricOrder,
9-
collectOverrideRevisionDetails
9+
collectOverrideRevisionDetails,
10+
collectCustomPlotsData
1011
} from './collect'
1112
import plotsDiffFixture from '../../test/fixtures/plotsDiff/output'
1213
import expShowFixture from '../../test/fixtures/expShow/base/output'
1314
import modifiedFixture from '../../test/fixtures/expShow/modified/output'
1415
import checkpointPlotsFixture from '../../test/fixtures/expShow/base/checkpointPlots'
16+
import customPlotsFixture from '../../test/fixtures/expShow/base/customPlots'
1517
import {
1618
ExperimentsOutput,
1719
ExperimentStatus,
@@ -27,6 +29,62 @@ const logsLossPath = join('logs', 'loss.tsv')
2729

2830
const logsLossPlot = (plotsDiffFixture[logsLossPath][0] || {}) as TemplatePlot
2931

32+
describe('collectCustomPlotsData', () => {
33+
it('should return the expected data from the text fixture', () => {
34+
const data = collectCustomPlotsData(
35+
[
36+
{
37+
metric: 'metrics:summary.json:loss',
38+
param: 'params:params.yaml:dropout'
39+
},
40+
{
41+
metric: 'metrics:summary.json:accuracy',
42+
param: 'params:params.yaml:epochs'
43+
}
44+
],
45+
[
46+
{
47+
id: '12345',
48+
label: '123',
49+
metrics: {
50+
'summary.json': {
51+
accuracy: 0.4668000042438507,
52+
loss: 2.0205044746398926
53+
}
54+
},
55+
name: 'exp-e7a67',
56+
params: { 'params.yaml': { dropout: 0.15, epochs: 16 } }
57+
},
58+
{
59+
id: '12345',
60+
label: '123',
61+
metrics: {
62+
'summary.json': {
63+
accuracy: 0.3484833240509033,
64+
loss: 1.9293040037155151
65+
}
66+
},
67+
name: 'exp-83425',
68+
params: { 'params.yaml': { dropout: 0.25, epochs: 10 } }
69+
},
70+
{
71+
id: '12345',
72+
label: '123',
73+
metrics: {
74+
'summary.json': {
75+
accuracy: 0.6768440509033,
76+
loss: 2.298503875732422
77+
}
78+
},
79+
name: 'exp-f13bca',
80+
params: { 'params.yaml': { dropout: 0.32, epochs: 20 } }
81+
}
82+
]
83+
)
84+
expect(data).toStrictEqual(customPlotsFixture.plots)
85+
})
86+
})
87+
3088
describe('collectCheckpointPlotsData', () => {
3189
it('should return the expected data from the test fixture', () => {
3290
const data = collectCheckpointPlotsData(expShowFixture)

extension/src/plots/model/collect.ts

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import omit from 'lodash.omit'
2+
import get from 'lodash.get'
23
import { TopLevelSpec } from 'vega-lite'
34
import { VisualizationSpec } from 'react-vega'
5+
import { CustomPlotsOrderValue } from '.'
46
import { getRevisionFirstThreeColumns } from './util'
57
import {
68
ColorScale,
@@ -13,7 +15,8 @@ import {
1315
TemplatePlotEntry,
1416
TemplatePlotSection,
1517
PlotsType,
16-
Revision
18+
Revision,
19+
CustomPlotData
1720
} from '../webview/contract'
1821
import {
1922
EXPERIMENT_WORKSPACE_ID,
@@ -28,9 +31,11 @@ import {
2831
import { extractColumns } from '../../experiments/columns/extract'
2932
import {
3033
decodeColumn,
31-
appendColumnToPath
34+
appendColumnToPath,
35+
splitColumnPath
3236
} from '../../experiments/columns/paths'
3337
import {
38+
ColumnType,
3439
Experiment,
3540
isRunning,
3641
MetricOrParamColumns
@@ -243,6 +248,48 @@ export const collectCheckpointPlotsData = (
243248
return plotsData
244249
}
245250

251+
export const getCustomPlotId = (metric: string, param: string) =>
252+
`custom-${metric}-${param}`
253+
254+
const collectCustomPlotData = (
255+
metric: string,
256+
param: string,
257+
experiments: Experiment[]
258+
): CustomPlotData => {
259+
const splitUpMetricPath = splitColumnPath(metric)
260+
const splitUpParamPath = splitColumnPath(param)
261+
const plotData: CustomPlotData = {
262+
id: getCustomPlotId(metric, param),
263+
metric: metric.slice(ColumnType.METRICS.length + 1),
264+
param: param.slice(ColumnType.PARAMS.length + 1),
265+
values: []
266+
}
267+
268+
for (const experiment of experiments) {
269+
const metricValue = get(experiment, splitUpMetricPath) as number | undefined
270+
const paramValue = get(experiment, splitUpParamPath) as number | undefined
271+
272+
if (metricValue !== undefined && paramValue !== undefined) {
273+
plotData.values.push({
274+
expName: experiment.name || experiment.label,
275+
metric: metricValue,
276+
param: paramValue
277+
})
278+
}
279+
}
280+
281+
return plotData
282+
}
283+
284+
export const collectCustomPlotsData = (
285+
metricsAndParams: CustomPlotsOrderValue[],
286+
experiments: Experiment[]
287+
): CustomPlotData[] => {
288+
return metricsAndParams.map(({ metric, param }) =>
289+
collectCustomPlotData(metric, param, experiments)
290+
)
291+
}
292+
246293
type MetricOrderAccumulator = {
247294
newOrder: string[]
248295
uncollectedMetrics: string[]

extension/src/plots/model/index.test.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ describe('plotsModel', () => {
103103
const expectedSectionCollapsed = {
104104
[Section.CHECKPOINT_PLOTS]: true,
105105
[Section.TEMPLATE_PLOTS]: false,
106+
[Section.CUSTOM_PLOTS]: false,
106107
[Section.COMPARISON_TABLE]: false
107108
}
108109

extension/src/plots/model/index.ts

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ import {
1111
RevisionData,
1212
TemplateAccumulator,
1313
collectCommitRevisionDetails,
14-
collectOverrideRevisionDetails
14+
collectOverrideRevisionDetails,
15+
collectCustomPlotsData,
16+
getCustomPlotId
1517
} from './collect'
1618
import { getRevisionFirstThreeColumns } from './util'
1719
import {
@@ -24,7 +26,8 @@ import {
2426
DEFAULT_SECTION_SIZES,
2527
Section,
2628
SectionCollapsed,
27-
PlotSizeNumber
29+
PlotSizeNumber,
30+
CustomPlotData
2831
} from '../webview/contract'
2932
import {
3033
ExperimentsOutput,
@@ -46,10 +49,13 @@ import {
4649
} from '../multiSource/collect'
4750
import { isDvcError } from '../../cli/dvc/reader'
4851

52+
export type CustomPlotsOrderValue = { metric: string; param: string }
53+
4954
export class PlotsModel extends ModelWithPersistence {
5055
private readonly experiments: Experiments
5156

5257
private plotSizes: Record<Section, number>
58+
private customPlotsOrder: CustomPlotsOrderValue[]
5359
private sectionCollapsed: SectionCollapsed
5460
private commitRevisions: Record<string, string> = {}
5561

@@ -64,6 +70,7 @@ export class PlotsModel extends ModelWithPersistence {
6470
private multiSourceEncoding: MultiSourceEncoding = {}
6571

6672
private checkpointPlots?: CheckpointPlot[]
73+
private customPlots?: CustomPlotData[]
6774
private selectedMetrics?: string[]
6875
private metricOrder: string[]
6976

@@ -89,6 +96,8 @@ export class PlotsModel extends ModelWithPersistence {
8996
undefined
9097
)
9198
this.metricOrder = this.revive(PersistenceKey.PLOT_METRIC_ORDER, [])
99+
100+
this.customPlotsOrder = this.revive(PersistenceKey.PLOTS_CUSTOM_ORDER, [])
92101
}
93102

94103
public transformAndSetExperiments(data: ExperimentsOutput) {
@@ -102,6 +111,8 @@ export class PlotsModel extends ModelWithPersistence {
102111

103112
this.setMetricOrder()
104113

114+
this.recreateCustomPlots()
115+
105116
return this.removeStaleData()
106117
}
107118

@@ -119,6 +130,8 @@ export class PlotsModel extends ModelWithPersistence {
119130
collectMultiSourceVariations(data, this.multiSourceVariations)
120131
])
121132

133+
this.recreateCustomPlots()
134+
122135
this.comparisonData = {
123136
...this.comparisonData,
124137
...comparisonData
@@ -127,7 +140,6 @@ export class PlotsModel extends ModelWithPersistence {
127140
...this.revisionData,
128141
...revisionData
129142
}
130-
131143
this.templates = { ...this.templates, ...templates }
132144
this.multiSourceVariations = multiSourceVariations
133145
this.multiSourceEncoding = collectMultiSourceEncoding(
@@ -171,6 +183,49 @@ export class PlotsModel extends ModelWithPersistence {
171183
}
172184
}
173185

186+
public getCustomPlots() {
187+
if (!this.customPlots) {
188+
return
189+
}
190+
return {
191+
plots: this.customPlots,
192+
size: this.getPlotSize(Section.CUSTOM_PLOTS)
193+
}
194+
}
195+
196+
public recreateCustomPlots() {
197+
const customPlots: CustomPlotData[] = collectCustomPlotsData(
198+
this.getCustomPlotsOrder(),
199+
this.experiments.getExperiments()
200+
)
201+
this.customPlots = customPlots
202+
}
203+
204+
public getCustomPlotsOrder() {
205+
return this.customPlotsOrder
206+
}
207+
208+
public setCustomPlotsOrder(plotsOrder: CustomPlotsOrderValue[]) {
209+
this.customPlotsOrder = plotsOrder
210+
this.persist(PersistenceKey.PLOTS_CUSTOM_ORDER, this.customPlotsOrder)
211+
this.recreateCustomPlots()
212+
}
213+
214+
public removeCustomPlots(plotIds: string[]) {
215+
const newCustomPlotsOrder = this.getCustomPlotsOrder().filter(
216+
({ metric, param }) => {
217+
return !plotIds.includes(getCustomPlotId(metric, param))
218+
}
219+
)
220+
221+
this.setCustomPlotsOrder(newCustomPlotsOrder)
222+
}
223+
224+
public addCustomPlot(metricAndParam: CustomPlotsOrderValue) {
225+
const newCustomPlotsOrder = [...this.getCustomPlotsOrder(), metricAndParam]
226+
this.setCustomPlotsOrder(newCustomPlotsOrder)
227+
}
228+
174229
public setupManualRefresh(id: string) {
175230
this.deleteRevisionData(id)
176231
}

0 commit comments

Comments
 (0)