Skip to content

Commit a1eed29

Browse files
authored
Add ability to stop experiments running in the workspace (outside of the extension) (#3247)
1 parent 4525e01 commit a1eed29

File tree

12 files changed

+133
-164
lines changed

12 files changed

+133
-164
lines changed

extension/package.json

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -820,7 +820,7 @@
820820
},
821821
{
822822
"command": "dvc.stopAllRunningExperiments",
823-
"when": "dvc.commands.available && dvc.project.available && dvc.experiment.stoppable"
823+
"when": "dvc.commands.available && dvc.project.available && dvc.experiment.running"
824824
},
825825
{
826826
"command": "dvc.views.experimentsColumnsTree.selectColumns",
@@ -993,7 +993,7 @@
993993
{
994994
"command": "dvc.stopAllRunningExperiments",
995995
"group": "navigation@0",
996-
"when": "dvc.experiment.stoppable && dvc.commands.available"
996+
"when": "dvc.experiment.running && dvc.commands.available"
997997
},
998998
{
999999
"command": "dvc.resetAndRunCheckpointExperiment",
@@ -1186,12 +1186,12 @@
11861186
},
11871187
{
11881188
"command": "dvc.stopAllRunningExperiments",
1189-
"when": "view == dvc.views.experimentsTree && !dvc.experiments.webview.active && dvc.experiment.stoppable",
1189+
"when": "view == dvc.views.experimentsTree && !dvc.experiments.webview.active && dvc.experiment.running",
11901190
"group": "1_run@1"
11911191
},
11921192
{
11931193
"command": "dvc.stopAllRunningExperiments",
1194-
"when": "view == dvc.views.experimentsTree && dvc.experiments.webview.active && dvc.experiment.stoppable",
1194+
"when": "view == dvc.views.experimentsTree && dvc.experiments.webview.active && dvc.experiment.running",
11951195
"group": "navigation@1"
11961196
},
11971197
{

extension/src/cli/dvc/constants.ts

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@ export const UNEXPECTED_ERROR_CODE = 255
44
export const DOT_DVC = '.dvc'
55

66
export const TEMP_PLOTS_DIR = join(DOT_DVC, 'tmp', 'plots')
7+
8+
const TEMP_EXP_DIR = join(DOT_DVC, 'tmp', 'exps')
79
export const DVCLIVE_ONLY_RUNNING_SIGNAL_FILE = join(
8-
DOT_DVC,
9-
'tmp',
10-
'exps',
10+
TEMP_EXP_DIR,
1111
'run',
1212
'DVCLIVE_ONLY'
1313
)
14+
export const EXP_RWLOCK_FILE = join(TEMP_EXP_DIR, 'rwlock.lock')
1415

1516
export const NUM_OF_COMMITS_TO_SHOW = '3'
1617

extension/src/context.ts

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ export class Context extends Disposable {
3737
repositories.push(this.experiments.getRepository(dvcRoot))
3838
}
3939

40-
void this.setIsExperimentRunning(repositories)
40+
this.setIsExperimentRunning(repositories)
4141

4242
void setContextValue(
4343
ContextKey.EXPERIMENTS_FILTERED,
@@ -52,24 +52,11 @@ export class Context extends Disposable {
5252
)
5353
}
5454

55-
private async setIsExperimentRunning(repositories: Experiments[] = []) {
56-
if (
57-
this.dvcRunner.isExperimentRunning() ||
58-
repositories.some(experiments => experiments.hasRunningQueuedExperiment())
59-
) {
60-
void setContextValue(ContextKey.EXPERIMENT_RUNNING, true)
61-
void setContextValue(ContextKey.EXPERIMENT_STOPPABLE, true)
62-
return
63-
}
64-
55+
private setIsExperimentRunning(repositories: Experiments[] = []) {
6556
void setContextValue(
6657
ContextKey.EXPERIMENT_RUNNING,
67-
repositories.some(experiments => experiments.hasRunningExperiment())
68-
)
69-
70-
void setContextValue(
71-
ContextKey.EXPERIMENT_STOPPABLE,
72-
await this.experiments.hasDvcLiveOnlyExperimentRunning()
58+
this.dvcRunner.isExperimentRunning() ||
59+
repositories.some(experiments => experiments.hasRunningExperiment())
7360
)
7461
}
7562
}

extension/src/experiments/index.ts

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -521,10 +521,6 @@ export class Experiments extends BaseRepository<TableData> {
521521
return this.experiments.hasRunningExperiment()
522522
}
523523

524-
public hasRunningQueuedExperiment() {
525-
return this.experiments.getRunningQueueTasks().length > 0
526-
}
527-
528524
public getFirstThreeColumnOrder() {
529525
return this.columns.getFirstThreeColumnOrder()
530526
}
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import { join } from 'path'
2+
import { collectRunningExperimentPids } from './collect'
3+
import { getPidFromFile } from '../../fileSystem'
4+
import {
5+
DVCLIVE_ONLY_RUNNING_SIGNAL_FILE,
6+
EXP_RWLOCK_FILE
7+
} from '../../cli/dvc/constants'
8+
9+
jest.mock('../../fileSystem')
10+
11+
const mockedGetPidFromFile = jest.mocked(getPidFromFile)
12+
13+
beforeEach(() => {
14+
jest.resetAllMocks()
15+
})
16+
17+
describe('collectRunningExperimentPids', () => {
18+
it('should exclude undefined from the final result', async () => {
19+
mockedGetPidFromFile
20+
.mockResolvedValueOnce(undefined)
21+
.mockResolvedValueOnce(1234)
22+
23+
const mockedDvcRoot = join('mock', 'root')
24+
25+
expect(await collectRunningExperimentPids([mockedDvcRoot])).toStrictEqual([
26+
1234
27+
])
28+
29+
expect(mockedGetPidFromFile).toHaveBeenCalledTimes(2)
30+
expect(mockedGetPidFromFile).toHaveBeenCalledWith(
31+
join(mockedDvcRoot, DVCLIVE_ONLY_RUNNING_SIGNAL_FILE)
32+
)
33+
expect(mockedGetPidFromFile).toHaveBeenCalledWith(
34+
join(mockedDvcRoot, EXP_RWLOCK_FILE)
35+
)
36+
})
37+
38+
it("should collect the pid of processes which are located in each repository's files", async () => {
39+
mockedGetPidFromFile
40+
.mockResolvedValueOnce(1)
41+
.mockResolvedValueOnce(2)
42+
.mockResolvedValueOnce(3)
43+
.mockResolvedValueOnce(4)
44+
45+
const mockedFirstDvcRoot = join('mock', 'root', '1')
46+
const mockedSecondDvcRoot = join('mock', 'root', '2')
47+
48+
expect(
49+
await collectRunningExperimentPids([
50+
mockedFirstDvcRoot,
51+
mockedSecondDvcRoot
52+
])
53+
).toStrictEqual([1, 2, 3, 4])
54+
55+
expect(mockedGetPidFromFile).toHaveBeenCalledTimes(4)
56+
expect(mockedGetPidFromFile).toHaveBeenCalledWith(
57+
join(mockedFirstDvcRoot, DVCLIVE_ONLY_RUNNING_SIGNAL_FILE)
58+
)
59+
expect(mockedGetPidFromFile).toHaveBeenCalledWith(
60+
join(mockedFirstDvcRoot, EXP_RWLOCK_FILE)
61+
)
62+
expect(mockedGetPidFromFile).toHaveBeenCalledWith(
63+
join(mockedSecondDvcRoot, DVCLIVE_ONLY_RUNNING_SIGNAL_FILE)
64+
)
65+
expect(mockedGetPidFromFile).toHaveBeenCalledWith(
66+
join(mockedSecondDvcRoot, EXP_RWLOCK_FILE)
67+
)
68+
})
69+
})
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import { join } from 'path'
2+
import {
3+
DVCLIVE_ONLY_RUNNING_SIGNAL_FILE,
4+
EXP_RWLOCK_FILE
5+
} from '../../cli/dvc/constants'
6+
import { getPidFromFile } from '../../fileSystem'
7+
8+
const collectDvcRootPids = async (
9+
acc: Set<number>,
10+
dvcRoot: string
11+
): Promise<void> => {
12+
for (const file of [
13+
join(dvcRoot, DVCLIVE_ONLY_RUNNING_SIGNAL_FILE),
14+
join(dvcRoot, EXP_RWLOCK_FILE)
15+
]) {
16+
const pid = await getPidFromFile(file)
17+
if (!pid) {
18+
continue
19+
}
20+
acc.add(pid)
21+
}
22+
}
23+
24+
export const collectRunningExperimentPids = async (
25+
dvcRoots: string[]
26+
): Promise<number[]> => {
27+
const acc = new Set<number>()
28+
for (const dvcRoot of dvcRoots) {
29+
await collectDvcRootPids(acc, dvcRoot)
30+
}
31+
return [...acc]
32+
}

extension/src/experiments/workspace.ts

Lines changed: 4 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
1-
import { join } from 'path'
21
import { EventEmitter, Memento } from 'vscode'
32
import isEmpty from 'lodash.isempty'
43
import { Experiments, ModifiedExperimentAndRunCommandId } from '.'
54
import { TableData } from './webview/contract'
6-
import { Args, DVCLIVE_ONLY_RUNNING_SIGNAL_FILE } from '../cli/dvc/constants'
5+
import { Args } from '../cli/dvc/constants'
76
import {
87
AvailableCommands,
98
CommandId,
@@ -15,8 +14,7 @@ import { getInput, getPositiveIntegerInput } from '../vscode/inputBox'
1514
import { BaseWorkspaceWebviews } from '../webview/workspace'
1615
import { Title } from '../vscode/title'
1716
import { ContextKey, setContextValue } from '../vscode/context'
18-
import { findOrCreateDvcYamlFile, getPidFromSignalFile } from '../fileSystem'
19-
import { definedAndNonEmpty } from '../util/array'
17+
import { findOrCreateDvcYamlFile } from '../fileSystem'
2018
import { quickPickOneOrInput } from '../vscode/quickPick'
2119
import { pickFile } from '../vscode/resourcePicker'
2220

@@ -415,28 +413,9 @@ export class WorkspaceExperiments extends BaseWorkspaceWebviews<
415413
return allLoading
416414
}
417415

418-
public async hasDvcLiveOnlyExperimentRunning() {
419-
return definedAndNonEmpty(await this.getDvcLiveOnlyPids())
420-
}
421-
422-
public async getDvcLiveOnlyPids() {
423-
const pids: number[] = []
424-
425-
for (const dvcRoot of this.getDvcRoots()) {
426-
const signalFile = join(dvcRoot, DVCLIVE_ONLY_RUNNING_SIGNAL_FILE)
427-
const pid = await getPidFromSignalFile(signalFile)
428-
if (!pid) {
429-
continue
430-
}
431-
pids.push(pid)
432-
}
433-
434-
return pids
435-
}
436-
437-
public hasQueuedExperimentsRunning() {
416+
public hasRunningExperiment() {
438417
return Object.values(this.repositories).some(experiments =>
439-
experiments.hasRunningQueuedExperiment()
418+
experiments.hasRunningExperiment()
440419
)
441420
}
442421

extension/src/extension.ts

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ import { definedAndNonEmpty } from './util/array'
5151
import { stopProcesses } from './processExecution'
5252
import { Flag } from './cli/dvc/constants'
5353
import { LanguageClient } from './languageClient'
54+
import { collectRunningExperimentPids } from './experiments/processes/collect'
5455

5556
export class Extension extends Disposable {
5657
protected readonly internalCommands: InternalCommands
@@ -200,21 +201,21 @@ export class Extension extends Disposable {
200201
RegisteredCommands.STOP_EXPERIMENTS,
201202
async () => {
202203
const stopWatch = new StopWatch()
203-
const dvcLiveOnlyPids = await this.experiments.getDvcLiveOnlyPids()
204+
const pids = await collectRunningExperimentPids(this.getRoots())
204205
const wasRunning =
205206
this.dvcRunner.isExperimentRunning() ||
206-
definedAndNonEmpty(dvcLiveOnlyPids) ||
207-
this.experiments.hasQueuedExperimentsRunning()
207+
definedAndNonEmpty(pids) ||
208+
this.experiments.hasRunningExperiment()
208209
try {
209-
const allStopped = await Promise.all([
210-
stopProcesses(dvcLiveOnlyPids),
211-
this.dvcRunner.stop(),
210+
const processesStopped = await Promise.all([
211+
stopProcesses(pids),
212212
...this.getRoots().map(dvcRoot =>
213213
this.dvcExecutor.queueStop(dvcRoot, Flag.KILL)
214214
)
215215
])
216+
const runnerStopped = await this.dvcRunner.stop()
216217

217-
const stopped = allStopped.every(Boolean)
218+
const stopped = processesStopped.every(Boolean) || runnerStopped
218219
sendTelemetryEvent(
219220
RegisteredCommands.STOP_EXPERIMENTS,
220221
{ stopped, wasRunning },

extension/src/fileSystem/index.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ export const writeJson = <T extends Record<string, unknown>>(
182182
return writeFileSync(path, JSON.stringify(obj))
183183
}
184184

185-
export const getPidFromSignalFile = async (
185+
export const getPidFromFile = async (
186186
path: string
187187
): Promise<number | undefined> => {
188188
if (!exists(path)) {
@@ -200,7 +200,7 @@ export const getPidFromSignalFile = async (
200200
}
201201

202202
export const checkSignalFile = async (path: string): Promise<boolean> => {
203-
return !!(await getPidFromSignalFile(path))
203+
return !!(await getPidFromFile(path))
204204
}
205205

206206
export const pollSignalFileForProcess = async (

0 commit comments

Comments
 (0)