Skip to content

Commit 85e0e66

Browse files
authored
Add jupyter notebook script to run in dvc.yaml (#3240)
* Add jupyter notebook script to run in dvc.yaml * Add tests * Edit tests for windows * Apply review comments
1 parent c83d074 commit 85e0e66

File tree

2 files changed

+95
-5
lines changed

2 files changed

+95
-5
lines changed

extension/src/fileSystem/index.test.ts

Lines changed: 83 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,31 @@
11
import { join, relative, resolve } from 'path'
2-
import { ensureDirSync, remove } from 'fs-extra'
2+
import { appendFileSync, ensureDirSync, ensureFileSync, remove } from 'fs-extra'
33
import {
44
exists,
55
findAbsoluteDvcRootPath,
66
findDvcRootPaths,
77
isDirectory,
88
isSameOrChild,
9-
getModifiedTime
9+
getModifiedTime,
10+
findOrCreateDvcYamlFile,
11+
scriptCommand
1012
} from '.'
1113
import { dvcDemoPath } from '../test/util'
1214
import { DOT_DVC } from '../cli/dvc/constants'
1315

1416
jest.mock('../cli/dvc/reader')
17+
jest.mock('fs-extra', () => {
18+
const actualModule = jest.requireActual('fs-extra')
19+
return {
20+
__esModule: true,
21+
...actualModule,
22+
appendFileSync: jest.fn(),
23+
ensureFileSync: jest.fn()
24+
}
25+
})
26+
27+
const mockedAppendFileSync = jest.mocked(appendFileSync)
28+
const mockedEnsureFileSync = jest.mocked(ensureFileSync)
1529

1630
beforeEach(() => {
1731
jest.resetAllMocks()
@@ -151,3 +165,70 @@ describe('getModifiedTime', () => {
151165
expect(epoch).toBeGreaterThan(1640995200000)
152166
})
153167
})
168+
169+
describe('findOrCreateDvcYamlFile', () => {
170+
it('should make sure a dvc.yaml file exists', () => {
171+
const cwd = '/cwd'
172+
findOrCreateDvcYamlFile(cwd, '/my/training/script.py')
173+
174+
expect(mockedEnsureFileSync).toHaveBeenCalledWith(`${cwd}/dvc.yaml`)
175+
})
176+
177+
it('should add the training script as a train stage in the dvc.yaml file', () => {
178+
const cwd = '/cwd'
179+
findOrCreateDvcYamlFile(cwd, '/my/training/script.py')
180+
181+
expect(mockedAppendFileSync).toHaveBeenCalledWith(
182+
`${cwd}/dvc.yaml`,
183+
expect.stringMatching(/^\s+stages:\s+train:/)
184+
)
185+
})
186+
187+
it('should add the training script as a relative path to the cwd', () => {
188+
findOrCreateDvcYamlFile(
189+
'/dir/my_project/',
190+
'/dir/my_project/src/training/train.py'
191+
)
192+
193+
expect(mockedAppendFileSync).toHaveBeenCalledWith(
194+
expect.anything(),
195+
expect.stringContaining(join('src', 'training', 'train.py'))
196+
)
197+
198+
findOrCreateDvcYamlFile(
199+
'/dir/my_project/',
200+
'/dir/my_other_project/train.py'
201+
)
202+
203+
expect(mockedAppendFileSync).toHaveBeenCalledWith(
204+
expect.anything(),
205+
expect.stringContaining(join('..', 'my_other_project', 'train.py'))
206+
)
207+
})
208+
209+
it('should use the jupyter nbconvert command if the training script is a Jupyter notebook', () => {
210+
findOrCreateDvcYamlFile('/', '/train.ipynb')
211+
212+
expect(mockedAppendFileSync).toHaveBeenCalledWith(
213+
expect.anything(),
214+
expect.stringContaining(scriptCommand.JUPYTER)
215+
)
216+
expect(mockedAppendFileSync).not.toHaveBeenCalledWith(
217+
expect.anything(),
218+
expect.stringContaining(scriptCommand.PYTHON)
219+
)
220+
})
221+
222+
it('should use the python command if the training script is not a Jupyter notebook', () => {
223+
findOrCreateDvcYamlFile('/', '/train.py')
224+
225+
expect(mockedAppendFileSync).not.toHaveBeenCalledWith(
226+
expect.anything(),
227+
expect.stringContaining(scriptCommand.JUPYTER)
228+
)
229+
expect(mockedAppendFileSync).toHaveBeenCalledWith(
230+
expect.anything(),
231+
expect.stringContaining(scriptCommand.PYTHON)
232+
)
233+
})
234+
})

extension/src/fileSystem/index.ts

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { basename, extname, join, relative, resolve, sep } from 'path'
1+
import { basename, extname, join, parse, relative, resolve, sep } from 'path'
22
import {
33
appendFileSync,
44
ensureFileSync,
@@ -131,16 +131,25 @@ export const isAnyDvcYaml = (path?: string): boolean =>
131131
basename(path) === 'dvc.yaml')
132132
)
133133

134+
export const scriptCommand = {
135+
JUPYTER: 'jupyter nbconvert --to notebook --inplace --execute',
136+
PYTHON: 'python'
137+
}
138+
134139
export const findOrCreateDvcYamlFile = (
135140
cwd: string,
136141
trainingScript: string
137142
) => {
138143
const dvcYamlPath = `${cwd}/dvc.yaml`
139144
ensureFileSync(dvcYamlPath)
140145

141-
const pipeline = `stages:
146+
const isNotebook = parse(trainingScript).ext === '.ipynb'
147+
const command = isNotebook ? scriptCommand.JUPYTER : scriptCommand.PYTHON
148+
149+
const pipeline = `
150+
stages:
142151
train:
143-
cmd: python ${relative(cwd, trainingScript)}`
152+
cmd: ${command} ${relative(cwd, trainingScript)}`
144153
return appendFileSync(dvcYamlPath, pipeline)
145154
}
146155

0 commit comments

Comments
 (0)