Skip to content

Commit 8580423

Browse files
committed
tests(//py): Restructing the nox file
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 80b906e commit 8580423

File tree

1 file changed

+238
-108
lines changed

1 file changed

+238
-108
lines changed

noxfile.py

Lines changed: 238 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from distutils.command.clean import clean
12
import nox
23
import os
34

@@ -8,7 +9,7 @@
89
# TOP_DIR
910
TOP_DIR=os.path.dirname(os.path.realpath(__file__)) if not 'TOP_DIR' in os.environ else os.environ["TOP_DIR"]
1011

11-
nox.options.sessions = ["developer_tests-3"]
12+
nox.options.sessions = ["l0_api_tests-3"]
1213

1314
def install_deps(session):
1415
print("Installing deps")
@@ -30,31 +31,6 @@ def install_torch_trt(session):
3031
session.chdir(os.path.join(TOP_DIR, "py"))
3132
session.run("python", "setup.py", "develop")
3233

33-
def run_base_tests(session, use_host_env=False):
34-
print("Running basic tests")
35-
session.chdir(os.path.join(TOP_DIR, 'tests/py'))
36-
tests = [
37-
"test_api.py",
38-
"test_to_backend_api.py"
39-
]
40-
for test in tests:
41-
if use_host_env:
42-
session.run_always('python', test, env={'PYTHONPATH': PYT_PATH})
43-
else:
44-
session.run_always("python", test)
45-
46-
47-
# Install the latest build of torch-tensorrt
48-
@nox.session(python=["3"], reuse_venv=True)
49-
def developer_tests(session):
50-
"""Basic set of tests that need to pass for code to get merged"""
51-
install_deps(session)
52-
install_torch_trt(session)
53-
download_models(session)
54-
run_base_tests(session)
55-
56-
# Download the dataset
57-
@nox.session(python=["3"], reuse_venv=True)
5834
def download_datasets(session):
5935
print("Downloading dataset to path", os.path.join(TOP_DIR, 'examples/int8/training/vgg16'))
6036
session.chdir(os.path.join(TOP_DIR, 'examples/int8/training/vgg16'))
@@ -68,98 +44,70 @@ def download_datasets(session):
6844
os.path.join(TOP_DIR, 'tests/accuracy/datasets/data/cidar-10-batches-bin'),
6945
external=True)
7046

71-
# Download the model
72-
@nox.session(python=["3"], reuse_venv=True)
73-
def download_test_models(session):
74-
download_models(session, use_host_env=True)
75-
76-
# Train the model
77-
@nox.session(python=["3"], reuse_venv=True)
78-
def train_model(session):
47+
def train_model(session, use_host_env=False):
7948
session.chdir(os.path.join(TOP_DIR, 'examples/int8/training/vgg16'))
80-
session.run_always('python',
81-
'main.py',
82-
'--lr', '0.01',
83-
'--batch-size', '128',
84-
'--drop-ratio', '0.15',
85-
'--ckpt-dir', 'vgg16_ckpts',
86-
'--epochs', '25',
87-
env={'PYTHONPATH': PYT_PATH})
88-
89-
# Export model
90-
session.run_always('python',
49+
if use_host_env:
50+
session.run_always('python',
51+
'main.py',
52+
'--lr', '0.01',
53+
'--batch-size', '128',
54+
'--drop-ratio', '0.15',
55+
'--ckpt-dir', 'vgg16_ckpts',
56+
'--epochs', '25',
57+
env={'PYTHONPATH': PYT_PATH})
58+
59+
session.run_always('python',
9160
'export_ckpt.py',
92-
'vgg16_ckpts/ckpt_epoch25.pth',
93-
env={'PYTHONPATH': PYT_PATH})
61+
'vgg16_ckpts/ckpt_epoch25.pth')
62+
else:
63+
session.run_always('python',
64+
'main.py',
65+
'--lr', '0.01',
66+
'--batch-size', '128',
67+
'--drop-ratio', '0.15',
68+
'--ckpt-dir', 'vgg16_ckpts',
69+
'--epochs', '25')
9470

95-
# Finetune the model
96-
@nox.session(python=["3"], reuse_venv=True)
97-
def finetune_model(session):
71+
session.run_always('python',
72+
'export_ckpt.py',
73+
'vgg16_ckpts/ckpt_epoch25.pth')
74+
75+
def finetune_model(session, use_host_env=False):
9876
# Install pytorch-quantization dependency
9977
session.install('pytorch-quantization', '--extra-index-url', 'https://pypi.ngc.nvidia.com')
100-
10178
session.chdir(os.path.join(TOP_DIR, 'examples/int8/training/vgg16'))
102-
session.run_always('python',
103-
'finetune_qat.py',
104-
'--lr', '0.01',
105-
'--batch-size', '128',
106-
'--drop-ratio', '0.15',
107-
'--ckpt-dir', 'vgg16_ckpts',
108-
'--start-from', '25',
109-
'--epochs', '26',
110-
env={'PYTHONPATH': PYT_PATH})
111-
112-
# Export model
113-
session.run_always('python',
114-
'export_qat.py',
115-
'vgg16_ckpts/ckpt_epoch26.pth',
116-
env={'PYTHONPATH': PYT_PATH})
117-
118-
# Run PTQ tests
119-
@nox.session(python=["3"], reuse_venv=True)
120-
def ptq_test(session):
121-
session.chdir(os.path.join(TOP_DIR, 'tests/py'))
122-
session.run_always('cp', '-rf',
123-
os.path.join(TOP_DIR, 'examples/int8/training/vgg16', 'trained_vgg16.jit.pt'),
124-
'.',
125-
external=True)
126-
tests = [
127-
'test_ptq_dataloader_calibrator.py',
128-
'test_ptq_to_backend.py',
129-
'test_ptq_trt_calibrator.py'
130-
]
131-
for test in tests:
132-
session.run_always('python', test,
133-
env={'PYTHONPATH': PYT_PATH})
13479

135-
# Run QAT tests
136-
@nox.session(python=["3"], reuse_venv=True)
137-
def qat_test(session):
138-
session.chdir(os.path.join(TOP_DIR, 'tests/py'))
139-
session.run_always('cp', '-rf',
140-
os.path.join(TOP_DIR, 'examples/int8/training/vgg16', 'trained_vgg16_qat.jit.pt'),
141-
'.',
142-
external=True)
143-
144-
session.run_always('python',
145-
'test_qat_trt_accuracy.py',
146-
env={'PYTHONPATH': PYT_PATH})
80+
if use_host_env:
81+
session.run_always('python',
82+
'finetune_qat.py',
83+
'--lr', '0.01',
84+
'--batch-size', '128',
85+
'--drop-ratio', '0.15',
86+
'--ckpt-dir', 'vgg16_ckpts',
87+
'--start-from', '25',
88+
'--epochs', '26',
89+
env={'PYTHONPATH': PYT_PATH})
14790

148-
# Run Python API tests
149-
@nox.session(python=["3"], reuse_venv=True)
150-
def api_test(session):
151-
session.chdir(os.path.join(TOP_DIR, 'tests/py'))
152-
tests = [
153-
"test_api.py",
154-
"test_to_backend_api.py"
155-
]
156-
for test in tests:
91+
# Export model
15792
session.run_always('python',
158-
test,
93+
'export_qat.py',
94+
'vgg16_ckpts/ckpt_epoch26.pth',
15995
env={'PYTHONPATH': PYT_PATH})
96+
else:
97+
session.run_always('python',
98+
'finetune_qat.py',
99+
'--lr', '0.01',
100+
'--batch-size', '128',
101+
'--drop-ratio', '0.15',
102+
'--ckpt-dir', 'vgg16_ckpts',
103+
'--start-from', '25',
104+
'--epochs', '26')
105+
106+
# Export model
107+
session.run_always('python',
108+
'export_qat.py',
109+
'vgg16_ckpts/ckpt_epoch26.pth')
160110

161-
# Clean up
162-
@nox.session(reuse_venv=True)
163111
def cleanup(session):
164112
target = [
165113
'examples/int8/training/vgg16/*.jit.pt',
@@ -173,4 +121,186 @@ def cleanup(session):
173121
target = ' '.join(x for x in [os.path.join(TOP_DIR, i) for i in target])
174122
session.run_always('bash', '-c',
175123
str('rm -rf ') + target,
176-
external=True)
124+
external=True)
125+
126+
def run_base_tests(session, use_host_env=False):
127+
print("Running basic tests")
128+
session.chdir(os.path.join(TOP_DIR, 'tests/py'))
129+
tests = [
130+
"test_api.py",
131+
"test_to_backend_api.py"
132+
]
133+
for test in tests:
134+
if use_host_env:
135+
session.run_always('python', test, env={'PYTHONPATH': PYT_PATH})
136+
else:
137+
session.run_always("python", test)
138+
139+
def run_accuracy_tests(session, use_host_env=False):
140+
print("Running accuracy tests")
141+
session.chdir(os.path.join(TOP_DIR, 'tests/py'))
142+
tests = []
143+
for test in tests:
144+
if use_host_env:
145+
session.run_always('python', test, env={'PYTHONPATH': PYT_PATH})
146+
else:
147+
session.run_always("python", test)
148+
149+
def run_int8_accuracy_tests(session, use_host_env=False):
150+
print("Running accuracy tests")
151+
session.chdir(os.path.join(TOP_DIR, 'tests/py'))
152+
tests = [
153+
"test_ptq_dataloader.py",
154+
"test_ptq_to_backend.py",
155+
"test_qat_trt_accuracy",
156+
]
157+
for test in tests:
158+
if use_host_env:
159+
session.run_always('python', test, env={'PYTHONPATH': PYT_PATH})
160+
else:
161+
session.run_always("python", test)
162+
163+
def run_trt_compatibility_tests(session, use_host_env=False):
164+
print("Running TensorRT compatibility tests")
165+
session.chdir(os.path.join(TOP_DIR, 'tests/py'))
166+
tests = [
167+
"test_trt_intercompatibilty.py",
168+
"test_ptq_trt_calibrator.py",
169+
]
170+
for test in tests:
171+
if use_host_env:
172+
session.run_always('python', test, env={'PYTHONPATH': PYT_PATH})
173+
else:
174+
session.run_always("python", test)
175+
176+
def run_dla_tests(session, use_host_env=False):
177+
print("Running DLA tests")
178+
session.chdir(os.path.join(TOP_DIR, 'tests/py'))
179+
tests = [
180+
"test_api_dla.py",
181+
]
182+
for test in tests:
183+
if use_host_env:
184+
session.run_always('python', test, env={'PYTHONPATH': PYT_PATH})
185+
else:
186+
session.run_always("python", test)
187+
188+
def run_multi_gpu_tests(session, use_host_env=False):
189+
print("Running multi GPU tests")
190+
session.chdir(os.path.join(TOP_DIR, 'tests/py'))
191+
tests = [
192+
"test_multi_gpu.py",
193+
]
194+
for test in tests:
195+
if use_host_env:
196+
session.run_always('python', test, env={'PYTHONPATH': PYT_PATH})
197+
else:
198+
session.run_always("python", test)
199+
200+
def run_l0_api_tests(session, use_host_env=False):
201+
if not use_host_env:
202+
install_deps(session)
203+
install_torch_trt(session)
204+
download_models(session, use_host_env)
205+
run_base_tests(session, use_host_env)
206+
cleanup(session)
207+
208+
def run_l0_dla_tests(session, use_host_env=False):
209+
if not use_host_env:
210+
install_deps(session)
211+
install_torch_trt(session)
212+
download_models(session, use_host_env)
213+
run_base_tests(session, use_host_env)
214+
cleanup(session)
215+
216+
def run_l1_accuracy_tests(session, use_host_env=False):
217+
if not use_host_env:
218+
install_deps(session)
219+
install_torch_trt(session)
220+
download_models(session, use_host_env)
221+
download_datasets(session, use_host_env)
222+
train_model(session, use_host_env)
223+
run_accuracy_tests(session, use_host_env)
224+
cleanup(session)
225+
226+
def run_l1_int8_accuracy_tests(session, use_host_env=False):
227+
if not use_host_env:
228+
install_deps(session)
229+
install_torch_trt(session)
230+
download_models(session, use_host_env)
231+
download_datasets(session, use_host_env)
232+
train_model(session, use_host_env)
233+
finetune_model(session, use_host_env)
234+
run_int8_accuracy_tests(session, use_host_env)
235+
cleanup(session)
236+
237+
def run_l2_trt_compatibility_tests(session, use_host_env=False):
238+
if not use_host_env:
239+
install_deps(session)
240+
install_torch_trt(session)
241+
download_models(session, use_host_env)
242+
run_trt_compatibility_tests(session, use_host_env)
243+
cleanup(session)
244+
245+
def run_l2_multi_gpu_tests(session, use_host_env=False):
246+
if not use_host_env:
247+
install_deps(session)
248+
install_torch_trt(session)
249+
download_models(session, use_host_env)
250+
run_multi_gpu_tests(session, use_host_env)
251+
cleanup(session)
252+
253+
@nox.session(python=["3"], reuse_venv=True)
254+
def l0_api_tests(session):
255+
"""When a developer needs to check correctness for a PR or something"""
256+
run_l0_api_tests(session, use_host_env=False)
257+
258+
@nox.session(python=["3"], reuse_venv=True)
259+
def l0_api_tests_host_deps(session):
260+
"""When a developer needs to check basic api functionality using host dependencies"""
261+
run_l0_api_tests(session, use_host_env=True)
262+
263+
@nox.session(python=["3"], reuse_venv=True)
264+
def l0_dla_tests_host_deps(session):
265+
"""When a developer needs to check basic api functionality using host dependencies"""
266+
run_l0_dla_tests(session, use_host_env=True)
267+
268+
@nox.session(python=["3"], reuse_venv=True)
269+
def l1_accuracy_tests(session):
270+
"""Checking accuracy performance on various usecases"""
271+
run_l1_accuracy_tests(session, use_host_env=False)
272+
273+
@nox.session(python=["3"], reuse_venv=True)
274+
def l1_accuracy_tests_host_deps(session):
275+
"""Checking accuracy performance on various usecases using host dependencies"""
276+
run_l1_accuracy_tests(session, use_host_env=True)
277+
278+
@nox.session(python=["3"], reuse_venv=True)
279+
def l1_int8_accuracy_tests(session):
280+
"""Checking accuracy performance on various usecases"""
281+
run_l1_int8_accuracy_tests(session, use_host_env=False)
282+
283+
@nox.session(python=["3"], reuse_venv=True)
284+
def l1_int8_accuracy_tests_host_deps(session):
285+
"""Checking accuracy performance on various usecases using host dependencies"""
286+
run_l1_int8_accuracy_tests(session, use_host_env=True)
287+
288+
@nox.session(python=["3"], reuse_venv=True)
289+
def l2_trt_compatibility_tests(session):
290+
"""Makes sure that TensorRT Python and Torch-TensorRT can work together"""
291+
run_l2_trt_compatibility_tests(session, use_host_env=False)
292+
293+
@nox.session(python=["3"], reuse_venv=True)
294+
def l2_trt_compatibility_tests_host_deps(session):
295+
"""Makes sure that TensorRT Python and Torch-TensorRT can work together using host dependencies"""
296+
run_l2_trt_compatibility_tests(session, use_host_env=True)
297+
298+
@nox.session(python=["3"], reuse_venv=True)
299+
def l2_multi_gpu_tests(session):
300+
"""Makes sure that Torch-TensorRT can operate on multi-gpu systems"""
301+
run_l2_multi_gpu_tests(session, use_host_env=False)
302+
303+
@nox.session(python=["3"], reuse_venv=True)
304+
def l2_multi_gpu_tests_host_deps(session):
305+
"""Makes sure that Torch-TensorRT can operate on multi-gpu systems using host dependencies"""
306+
run_l2_multi_gpu_tests(session, use_host_env=True)

0 commit comments

Comments
 (0)