Skip to content

Commit 80b906e

Browse files
committed
tests(nox): Adding a developer test suite
After installing nox, the whole test stack can be run by simply running `nox`. This will create a venv, install the correct version of pytorch and tests deps, build and install torch-tensorrt download models and run the developer test suite. The env is persistent so the step up steps are cached Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 7191959 commit 80b906e

File tree

1 file changed

+71
-29
lines changed

1 file changed

+71
-29
lines changed

noxfile.py

Lines changed: 71 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,54 @@
44
# Use system installed Python packages
55
PYT_PATH='/opt/conda/lib/python3.8/site-packages' if not 'PYT_PATH' in os.environ else os.environ["PYT_PATH"]
66

7-
# Root directory for torch_tensorrt. Set according to docker container by default
8-
TOP_DIR='/torchtrt' if not 'TOP_DIR' in os.environ else os.environ["TOP_DIR"]
7+
# Set the root directory to the directory of the noxfile unless the user wants to
8+
# TOP_DIR
9+
TOP_DIR=os.path.dirname(os.path.realpath(__file__)) if not 'TOP_DIR' in os.environ else os.environ["TOP_DIR"]
10+
11+
nox.options.sessions = ["developer_tests-3"]
12+
13+
def install_deps(session):
14+
print("Installing deps")
15+
session.install("-r", os.path.join(TOP_DIR, "py", "requirements.txt"))
16+
session.install("-r", os.path.join(TOP_DIR, "tests", "py", "requirements.txt"))
17+
18+
def download_models(session, use_host_env=False):
19+
print("Downloading test models")
20+
session.install('timm')
21+
print(TOP_DIR)
22+
session.chdir(os.path.join(TOP_DIR, "tests", "modules"))
23+
if use_host_env:
24+
session.run_always('python', 'hub.py', env={'PYTHONPATH': PYT_PATH})
25+
else:
26+
session.run_always('python', 'hub.py')
27+
28+
def install_torch_trt(session):
29+
print("Installing latest torch-tensorrt build")
30+
session.chdir(os.path.join(TOP_DIR, "py"))
31+
session.run("python", "setup.py", "develop")
32+
33+
def run_base_tests(session, use_host_env=False):
34+
print("Running basic tests")
35+
session.chdir(os.path.join(TOP_DIR, 'tests/py'))
36+
tests = [
37+
"test_api.py",
38+
"test_to_backend_api.py"
39+
]
40+
for test in tests:
41+
if use_host_env:
42+
session.run_always('python', test, env={'PYTHONPATH': PYT_PATH})
43+
else:
44+
session.run_always("python", test)
45+
46+
47+
# Install the latest build of torch-tensorrt
48+
@nox.session(python=["3"], reuse_venv=True)
49+
def developer_tests(session):
50+
"""Basic set of tests that need to pass for code to get merged"""
51+
install_deps(session)
52+
install_torch_trt(session)
53+
download_models(session)
54+
run_base_tests(session)
955

1056
# Download the dataset
1157
@nox.session(python=["3"], reuse_venv=True)
@@ -14,33 +60,29 @@ def download_datasets(session):
1460
session.chdir(os.path.join(TOP_DIR, 'examples/int8/training/vgg16'))
1561
session.run_always('wget', 'https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz', external=True)
1662
session.run_always('tar', '-xvzf', 'cifar-10-binary.tar.gz', external=True)
17-
session.run_always('mkdir', '-p',
63+
session.run_always('mkdir', '-p',
1864
os.path.join(TOP_DIR, 'tests/accuracy/datasets/data'),
1965
external=True)
20-
session.run_always('cp', '-rpf',
66+
session.run_always('cp', '-rpf',
2167
os.path.join(TOP_DIR, 'examples/int8/training/vgg16/cifar-10-batches-bin'),
2268
os.path.join(TOP_DIR, 'tests/accuracy/datasets/data/cidar-10-batches-bin'),
2369
external=True)
2470

2571
# Download the model
2672
@nox.session(python=["3"], reuse_venv=True)
27-
def download_models(session):
28-
session.install('timm')
29-
session.chdir('tests/modules')
30-
session.run_always('python',
31-
'hub.py',
32-
env={'PYTHONPATH': PYT_PATH})
73+
def download_test_models(session):
74+
download_models(session, use_host_env=True)
3375

3476
# Train the model
3577
@nox.session(python=["3"], reuse_venv=True)
3678
def train_model(session):
3779
session.chdir(os.path.join(TOP_DIR, 'examples/int8/training/vgg16'))
38-
session.run_always('python',
39-
'main.py',
40-
'--lr', '0.01',
41-
'--batch-size', '128',
42-
'--drop-ratio', '0.15',
43-
'--ckpt-dir', 'vgg16_ckpts',
80+
session.run_always('python',
81+
'main.py',
82+
'--lr', '0.01',
83+
'--batch-size', '128',
84+
'--drop-ratio', '0.15',
85+
'--ckpt-dir', 'vgg16_ckpts',
4486
'--epochs', '25',
4587
env={'PYTHONPATH': PYT_PATH})
4688

@@ -57,17 +99,17 @@ def finetune_model(session):
5799
session.install('pytorch-quantization', '--extra-index-url', 'https://pypi.ngc.nvidia.com')
58100

59101
session.chdir(os.path.join(TOP_DIR, 'examples/int8/training/vgg16'))
60-
session.run_always('python',
61-
'finetune_qat.py',
62-
'--lr', '0.01',
63-
'--batch-size', '128',
64-
'--drop-ratio', '0.15',
65-
'--ckpt-dir', 'vgg16_ckpts',
66-
'--start-from', '25',
102+
session.run_always('python',
103+
'finetune_qat.py',
104+
'--lr', '0.01',
105+
'--batch-size', '128',
106+
'--drop-ratio', '0.15',
107+
'--ckpt-dir', 'vgg16_ckpts',
108+
'--start-from', '25',
67109
'--epochs', '26',
68110
env={'PYTHONPATH': PYT_PATH})
69-
70-
# Export model
111+
112+
# Export model
71113
session.run_always('python',
72114
'export_qat.py',
73115
'vgg16_ckpts/ckpt_epoch26.pth',
@@ -77,8 +119,8 @@ def finetune_model(session):
77119
@nox.session(python=["3"], reuse_venv=True)
78120
def ptq_test(session):
79121
session.chdir(os.path.join(TOP_DIR, 'tests/py'))
80-
session.run_always('cp', '-rf',
81-
os.path.join(TOP_DIR, 'examples/int8/training/vgg16', 'trained_vgg16.jit.pt'),
122+
session.run_always('cp', '-rf',
123+
os.path.join(TOP_DIR, 'examples/int8/training/vgg16', 'trained_vgg16.jit.pt'),
82124
'.',
83125
external=True)
84126
tests = [
@@ -94,8 +136,8 @@ def ptq_test(session):
94136
@nox.session(python=["3"], reuse_venv=True)
95137
def qat_test(session):
96138
session.chdir(os.path.join(TOP_DIR, 'tests/py'))
97-
session.run_always('cp', '-rf',
98-
os.path.join(TOP_DIR, 'examples/int8/training/vgg16', 'trained_vgg16_qat.jit.pt'),
139+
session.run_always('cp', '-rf',
140+
os.path.join(TOP_DIR, 'examples/int8/training/vgg16', 'trained_vgg16_qat.jit.pt'),
99141
'.',
100142
external=True)
101143

0 commit comments

Comments
 (0)