Skip to content

Commit 1310d66

Browse files
authored
Enable running exp apply and exp branch against commits (#3834)
1 parent ea310a5 commit 1310d66

File tree

10 files changed

+100
-25
lines changed

10 files changed

+100
-25
lines changed

extension/package.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1163,12 +1163,12 @@
11631163
{
11641164
"command": "dvc.views.experiments.applyExperiment",
11651165
"group": "inline@1",
1166-
"when": "view == dvc.views.experimentsTree && dvc.commands.available && viewItem == experiment && !dvc.experiment.running.workspace"
1166+
"when": "view == dvc.views.experimentsTree && dvc.commands.available && viewItem =~ /^(experiment|commit)$/ && !dvc.experiment.running.workspace"
11671167
},
11681168
{
11691169
"command": "dvc.views.experiments.branchExperiment",
11701170
"group": "inline@2",
1171-
"when": "view == dvc.views.experimentsTree && dvc.commands.available && viewItem == experiment && !dvc.experiment.running.workspace"
1171+
"when": "view == dvc.views.experimentsTree && dvc.commands.available && viewItem =~ /^(experiment|commit)$/ && !dvc.experiment.running.workspace"
11721172
},
11731173
{
11741174
"command": "dvc.views.experimentsTree.removeExperiment",

extension/src/experiments/index.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -371,9 +371,9 @@ export class Experiments extends BaseRepository<TableData> {
371371
return this.notifyChanged()
372372
}
373373

374-
public pickExperiment() {
374+
public pickCommitOrExperiment() {
375375
return pickExperiment(
376-
this.experiments.getExperiments(),
376+
this.experiments.getCommitsAndExperiments(),
377377
this.getFirstThreeColumnOrder()
378378
)
379379
}

extension/src/experiments/model/collect.ts

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,12 @@ import {
66
} from 'vscode'
77
import { ExperimentType } from '.'
88
import { extractColumns } from '../columns/extract'
9-
import { Experiment, CommitData, RunningExperiment } from '../webview/contract'
9+
import {
10+
Experiment,
11+
CommitData,
12+
RunningExperiment,
13+
isQueued
14+
} from '../webview/contract'
1015
import {
1116
EXPERIMENT_WORKSPACE_ID,
1217
ExperimentStatus,
@@ -372,3 +377,28 @@ export const collectExperimentType = (
372377

373378
return acc
374379
}
380+
381+
const collectExperimentsAndCommit = (
382+
acc: Experiment[],
383+
commit: Experiment,
384+
experiments: Experiment[] = []
385+
): void => {
386+
acc.push(commit)
387+
for (const experiment of experiments) {
388+
if (isQueued(experiment.status)) {
389+
continue
390+
}
391+
acc.push(experiment)
392+
}
393+
}
394+
395+
export const collectOrderedCommitsAndExperiments = (
396+
commits: Experiment[],
397+
getExperimentsByCommit: (commit: Experiment) => Experiment[] | undefined
398+
): Experiment[] => {
399+
const acc: Experiment[] = []
400+
for (const commit of commits) {
401+
collectExperimentsAndCommit(acc, commit, getExperimentsByCommit(commit))
402+
}
403+
return acc
404+
}

extension/src/experiments/model/index.ts

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import { Memento } from 'vscode'
22
import { SortDefinition, sortExperiments } from './sortBy'
33
import { FilterDefinition, filterExperiment, getFilterId } from './filterBy'
4-
import { collectExperiments } from './collect'
4+
import {
5+
collectExperiments,
6+
collectOrderedCommitsAndExperiments
7+
} from './collect'
58
import {
69
collectColoredStatus,
710
collectFinishedRunningExperiments,
@@ -319,6 +322,12 @@ export class ExperimentsModel extends ModelWithPersistence {
319322
})
320323
}
321324

325+
public getCommitsAndExperiments() {
326+
return collectOrderedCommitsAndExperiments(this.commits, commit =>
327+
this.getExperimentsByCommit(commit)
328+
)
329+
}
330+
322331
public getExperimentsAndQueued() {
323332
return flattenMapValues(this.experimentsByCommit).map(experiment =>
324333
this.addDetails(experiment)

extension/src/experiments/workspace.test.ts

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ const mockedQuickPickOne = jest.mocked(quickPickOne)
3333
const mockedQuickPickManyValues = jest.mocked(quickPickManyValues)
3434
const mockedQuickPickOneOrInput = jest.mocked(quickPickOneOrInput)
3535
const mockedGetValidInput = jest.mocked(getValidInput)
36-
const mockedPickExperiment = jest.fn()
36+
const mockedPickCommitOrExperiment = jest.fn()
3737
const mockedGetInput = jest.mocked(getInput)
3838
const mockedRun = jest.fn()
3939
const mockedExpFunc = jest.fn()
@@ -91,7 +91,7 @@ describe('Experiments', () => {
9191
{
9292
'/my/dvc/root': {
9393
getDvcRoot: () => mockedDvcRoot,
94-
pickExperiment: mockedPickExperiment,
94+
pickCommitOrExperiment: mockedPickCommitOrExperiment,
9595
showWebview: mockedShowWebview
9696
} as unknown as Experiments,
9797
'/my/fun/dvc/root': {
@@ -138,12 +138,12 @@ describe('Experiments', () => {
138138
it('should call the correct function with the correct parameters if a project and experiment are picked', async () => {
139139
mockedQuickPickOne.mockResolvedValueOnce(mockedDvcRoot)
140140
mockedListStages.mockResolvedValueOnce('train')
141-
mockedPickExperiment.mockResolvedValueOnce('a123456')
141+
mockedPickCommitOrExperiment.mockResolvedValueOnce('a123456')
142142

143143
await workspaceExperiments.getCwdAndExpNameThenRun(mockedCommandId)
144144

145145
expect(mockedQuickPickOne).toHaveBeenCalledTimes(1)
146-
expect(mockedPickExperiment).toHaveBeenCalledTimes(1)
146+
expect(mockedPickCommitOrExperiment).toHaveBeenCalledTimes(1)
147147
expect(mockedExpFunc).toHaveBeenCalledTimes(1)
148148
expect(mockedExpFunc).toHaveBeenCalledWith(mockedDvcRoot, 'a123456')
149149
})
@@ -240,7 +240,7 @@ describe('Experiments', () => {
240240
it('should call the correct function with the correct parameters if a project and experiment are picked and an input provided', async () => {
241241
mockedQuickPickOne.mockResolvedValueOnce(mockedDvcRoot)
242242
mockedListStages.mockResolvedValueOnce('train')
243-
mockedPickExperiment.mockResolvedValueOnce('a123456')
243+
mockedPickCommitOrExperiment.mockResolvedValueOnce('a123456')
244244
mockedGetInput.mockResolvedValueOnce('abc123')
245245

246246
await workspaceExperiments.getCwdExpNameAndInputThenRun(
@@ -250,7 +250,7 @@ describe('Experiments', () => {
250250
)
251251

252252
expect(mockedQuickPickOne).toHaveBeenCalledTimes(1)
253-
expect(mockedPickExperiment).toHaveBeenCalledTimes(1)
253+
expect(mockedPickCommitOrExperiment).toHaveBeenCalledTimes(1)
254254
expect(mockedExpFunc).toHaveBeenCalledTimes(1)
255255
expect(mockedExpFunc).toHaveBeenCalledWith(
256256
mockedDvcRoot,
@@ -276,7 +276,7 @@ describe('Experiments', () => {
276276
it('should not call the function if user input is not provided', async () => {
277277
mockedQuickPickOne.mockResolvedValueOnce(mockedDvcRoot)
278278
mockedListStages.mockResolvedValueOnce('train')
279-
mockedPickExperiment.mockResolvedValueOnce({
279+
mockedPickCommitOrExperiment.mockResolvedValueOnce({
280280
id: 'b456789',
281281
name: 'exp-456'
282282
})
@@ -296,7 +296,7 @@ describe('Experiments', () => {
296296
it('should check and ask for the creation of a pipeline stage before running the command', async () => {
297297
mockedQuickPickOne.mockResolvedValueOnce(mockedDvcRoot)
298298
mockedListStages.mockResolvedValueOnce('')
299-
mockedPickExperiment.mockResolvedValueOnce({
299+
mockedPickCommitOrExperiment.mockResolvedValueOnce({
300300
id: 'a123456',
301301
name: 'exp-123'
302302
})

extension/src/experiments/workspace.ts

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,9 @@ export class WorkspaceExperiments extends BaseWorkspaceWebviews<
210210
}
211211

212212
public getCwdAndExpNameThenRun(commandId: CommandId) {
213-
return this.pickExpThenRun(commandId, cwd => this.pickExperiment(cwd))
213+
return this.pickExpThenRun(commandId, cwd =>
214+
this.pickCommitOrExperiment(cwd)
215+
)
214216
}
215217

216218
public async getCwdAndQuickPickThenRun(
@@ -237,7 +239,7 @@ export class WorkspaceExperiments extends BaseWorkspaceWebviews<
237239
return
238240
}
239241

240-
const experimentId = await this.pickExperiment(cwd)
242+
const experimentId = await this.pickCommitOrExperiment(cwd)
241243

242244
if (!experimentId) {
243245
return
@@ -545,7 +547,7 @@ export class WorkspaceExperiments extends BaseWorkspaceWebviews<
545547
return this.runCommand(commandId, cwd, experimentId)
546548
}
547549

548-
private pickExperiment(cwd: string) {
549-
return this.getRepository(cwd).pickExperiment()
550+
private pickCommitOrExperiment(cwd: string) {
551+
return this.getRepository(cwd).pickCommitOrExperiment()
550552
}
551553
}

extension/src/test/suite/experiments/workspace.test.ts

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -542,7 +542,7 @@ suite('Workspace Experiments Test Suite', () => {
542542
})
543543

544544
describe('dvc.applyExperiment', () => {
545-
it('should ask the user to pick an experiment and then apply that experiment to the workspace', async () => {
545+
it('should ask the user to pick a commit or experiment and then apply it to the workspace', async () => {
546546
const selectedExperiment = 'test-branch'
547547

548548
const { experiments } = buildExperiments(disposable)
@@ -562,8 +562,17 @@ suite('Workspace Experiments Test Suite', () => {
562562
dvcDemoPath,
563563
selectedExperiment
564564
)
565+
565566
expect(mockShowQuickPick).to.be.calledWith(
566567
[
568+
{
569+
description: undefined,
570+
detail: `Created:${formatDate(
571+
'2020-11-21T19:58:22'
572+
)}, loss:2.0488560, accuracy:0.34848332`,
573+
label: 'main',
574+
value: 'main'
575+
},
567576
{
568577
description: '[exp-e7a67]',
569578
detail: `Created:${formatDate(

webview/src/experiments/components/App.test.tsx

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -858,7 +858,7 @@ describe('App', () => {
858858
expect(itemLabels).toStrictEqual(['Modify and Run', 'Modify and Queue'])
859859
})
860860

861-
it('should enable the correct options for the main row with checkpoints', () => {
861+
it('should enable the correct options for a commit with checkpoints', () => {
862862
renderTableWithoutRunningExperiments()
863863

864864
const target = screen.getByText('main')
@@ -870,13 +870,35 @@ describe('App', () => {
870870
.filter(item => !item.className.includes('disabled'))
871871
.map(item => item.textContent)
872872
expect(itemLabels).toStrictEqual([
873+
'Apply to Workspace',
874+
'Create new Branch',
873875
'Modify and Run',
874876
'Modify and Resume',
875877
'Modify and Queue',
876878
'Star'
877879
])
878880
})
879881

882+
it('should enable the correct options for a commit without checkpoints', () => {
883+
renderTableWithoutRunningExperiments(false)
884+
885+
const target = screen.getByText('main')
886+
fireEvent.contextMenu(target, { bubbles: true })
887+
888+
advanceTimersByTime(100)
889+
const menuitems = screen.getAllByRole('menuitem')
890+
const itemLabels = menuitems
891+
.filter(item => !item.className.includes('disabled'))
892+
.map(item => item.textContent)
893+
expect(itemLabels).toStrictEqual([
894+
'Apply to Workspace',
895+
'Create new Branch',
896+
'Modify and Run',
897+
'Modify and Queue',
898+
'Star'
899+
])
900+
})
901+
880902
it('should enable the correct options for an experiment that is not running and close on esc', () => {
881903
renderTableWithoutRunningExperiments()
882904

webview/src/experiments/components/table/body/RowContextMenu.tsx

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -229,11 +229,11 @@ const getSingleSelectMenuOptions = (
229229
divider
230230
)
231231

232-
const disableIfRunningOrNotExperiment = (
232+
const disableIfRunningOrWorkspace = (
233233
label: string,
234234
type: MessageFromWebviewType,
235235
divider?: boolean
236-
) => disableIfRunning(label, type, isNotExperiment, divider)
236+
) => disableIfRunning(label, type, isWorkspace, divider)
237237

238238
return [
239239
experimentMenuOption(
@@ -242,11 +242,11 @@ const getSingleSelectMenuOptions = (
242242
MessageFromWebviewType.SHOW_EXPERIMENT_LOGS,
243243
!isRunningInQueue({ executor, status })
244244
),
245-
disableIfRunningOrNotExperiment(
245+
disableIfRunningOrWorkspace(
246246
'Apply to Workspace',
247247
MessageFromWebviewType.APPLY_EXPERIMENT_TO_WORKSPACE
248248
),
249-
disableIfRunningOrNotExperiment(
249+
disableIfRunningOrWorkspace(
250250
'Create new Branch',
251251
MessageFromWebviewType.CREATE_BRANCH_FROM_EXPERIMENT
252252
),

webview/src/test/experimentsTable.tsx

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,12 @@ export const renderTableWithSortingData = () => {
6060
return renderTable(sortingTableDataFixture)
6161
}
6262

63-
export const renderTableWithoutRunningExperiments = () => {
63+
export const renderTableWithoutRunningExperiments = (
64+
hasCheckpoints?: boolean
65+
) => {
6466
renderTable({
6567
...tableDataFixture,
68+
hasCheckpoints: hasCheckpoints ?? tableDataFixture.hasCheckpoints,
6669
hasRunningWorkspaceExperiment: false,
6770
rows: tableDataFixture.rows.map(row => ({
6871
...row,

0 commit comments

Comments
 (0)