Skip to content

Commit beeac7c

Browse files
committed
refactor: Refactor nox file testing
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent c6f3103 commit beeac7c

File tree

12 files changed

+75 additions, -89 deletions

.circleci/config.yml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -435,6 +435,7 @@ commands:
435435
mkdir -p /tmp/artifacts/test_results
436436
cd tests/py
437437
pytest --junitxml=/tmp/artifacts/test_results/api/api_test_results.xml api/
438+
pytest --junitxml=/tmp/artifacts/test_results/models/models_test_results.xml models/
438439
pytest --junitxml=/tmp/artifacts/test_results/integrations/integrations_test_results.xml integrations/
439440
cd ~/project
440441

noxfile.py

Lines changed: 36 additions & 65 deletions
Original file line number | Diff line number | Diff line change
@@ -30,13 +30,15 @@
3030
if USE_HOST_DEPS:
3131
print("Using dependencies from host python")
3232

33+
# Set epochs to train VGG model for accuracy tests
34+
EPOCHS=25
35+
3336
SUPPORTED_PYTHON_VERSIONS = ["3.7", "3.8", "3.9", "3.10"]
3437

3538
nox.options.sessions = [
3639
"l0_api_tests-" + "{}.{}".format(sys.version_info.major, sys.version_info.minor)
3740
]
3841

39-
4042
def install_deps(session):
4143
print("Installing deps")
4244
session.install("-r", os.path.join(TOP_DIR, "py", "requirements.txt"))
@@ -63,31 +65,6 @@ def install_torch_trt(session):
6365
session.run("python", "setup.py", "develop")
6466

6567

66-
def download_datasets(session):
67-
print(
68-
"Downloading dataset to path",
69-
os.path.join(TOP_DIR, "examples/int8/training/vgg16"),
70-
)
71-
session.chdir(os.path.join(TOP_DIR, "examples/int8/training/vgg16"))
72-
session.run_always(
73-
"wget", "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz", external=True
74-
)
75-
session.run_always("tar", "-xvzf", "cifar-10-binary.tar.gz", external=True)
76-
session.run_always(
77-
"mkdir",
78-
"-p",
79-
os.path.join(TOP_DIR, "tests/accuracy/datasets/data"),
80-
external=True,
81-
)
82-
session.run_always(
83-
"cp",
84-
"-rpf",
85-
os.path.join(TOP_DIR, "examples/int8/training/vgg16/cifar-10-batches-bin"),
86-
os.path.join(TOP_DIR, "tests/accuracy/datasets/data/cidar-10-batches-bin"),
87-
external=True,
88-
)
89-
90-
9168
def train_model(session):
9269
session.chdir(os.path.join(TOP_DIR, "examples/int8/training/vgg16"))
9370
session.install("-r", "requirements.txt")
@@ -107,14 +84,14 @@ def train_model(session):
10784
"--ckpt-dir",
10885
"vgg16_ckpts",
10986
"--epochs",
110-
"25",
87+
str(EPOCHS),
11188
env={"PYTHONPATH": PYT_PATH},
11289
)
11390

11491
session.run_always(
11592
"python",
11693
"export_ckpt.py",
117-
"vgg16_ckpts/ckpt_epoch25.pth",
94+
"vgg16_ckpts/ckpt_epoch" + str(EPOCHS) + ".pth",
11895
env={"PYTHONPATH": PYT_PATH},
11996
)
12097
else:
@@ -130,10 +107,10 @@ def train_model(session):
130107
"--ckpt-dir",
131108
"vgg16_ckpts",
132109
"--epochs",
133-
"25",
110+
str(EPOCHS),
134111
)
135112

136-
session.run_always("python", "export_ckpt.py", "vgg16_ckpts/ckpt_epoch25.pth")
113+
session.run_always("python", "export_ckpt.py", "vgg16_ckpts/ckpt_epoch" + str(EPOCHS) + ".pth")
137114

138115

139116
def finetune_model(session):
@@ -156,17 +133,17 @@ def finetune_model(session):
156133
"--ckpt-dir",
157134
"vgg16_ckpts",
158135
"--start-from",
159-
"25",
136+
str(EPOCHS),
160137
"--epochs",
161-
"26",
138+
str(EPOCHS+1),
162139
env={"PYTHONPATH": PYT_PATH},
163140
)
164141

165142
# Export model
166143
session.run_always(
167144
"python",
168145
"export_qat.py",
169-
"vgg16_ckpts/ckpt_epoch26.pth",
146+
"vgg16_ckpts/ckpt_epoch" + str(EPOCHS+1) + ".pth",
170147
env={"PYTHONPATH": PYT_PATH},
171148
)
172149
else:
@@ -182,13 +159,13 @@ def finetune_model(session):
182159
"--ckpt-dir",
183160
"vgg16_ckpts",
184161
"--start-from",
185-
"25",
162+
str(EPOCHS),
186163
"--epochs",
187-
"26",
164+
str(EPOCHS+1),
188165
)
189166

190167
# Export model
191-
session.run_always("python", "export_qat.py", "vgg16_ckpts/ckpt_epoch26.pth")
168+
session.run_always("python", "export_qat.py", "vgg16_ckpts/ckpt_epoch" + str(EPOCHS+1) + ".pth")
192169

193170

194171
def cleanup(session):
@@ -209,7 +186,7 @@ def run_base_tests(session):
209186
print("Running basic tests")
210187
session.chdir(os.path.join(TOP_DIR, "tests/py"))
211188
tests = [
212-
"api",
189+
"api/test_e2e_behavior.py",
213190
"integrations/test_to_backend_api.py",
214191
]
215192
for test in tests:
@@ -218,6 +195,18 @@ def run_base_tests(session):
218195
else:
219196
session.run_always("pytest", test)
220197

198+
def run_model_tests(session):
199+
print("Running model tests")
200+
session.chdir(os.path.join(TOP_DIR, "tests/py"))
201+
tests = [
202+
"models",
203+
]
204+
for test in tests:
205+
if USE_HOST_DEPS:
206+
session.run_always("pytest", test, env={"PYTHONPATH": PYT_PATH})
207+
else:
208+
session.run_always("pytest", test)
209+
221210

222211
def run_accuracy_tests(session):
223212
print("Running accuracy tests")
@@ -268,8 +257,8 @@ def run_trt_compatibility_tests(session):
268257
copy_model(session)
269258
session.chdir(os.path.join(TOP_DIR, "tests/py"))
270259
tests = [
271-
"test_trt_intercompatibility.py",
272-
"test_ptq_trt_calibrator.py",
260+
"integrations/test_trt_intercompatibility.py",
261+
#"ptq/test_ptq_trt_calibrator.py",
273262
]
274263
for test in tests:
275264
if USE_HOST_DEPS:
@@ -282,7 +271,7 @@ def run_dla_tests(session):
282271
print("Running DLA tests")
283272
session.chdir(os.path.join(TOP_DIR, "tests/py"))
284273
tests = [
285-
"test_api_dla.py",
274+
"hw/test_api_dla.py",
286275
]
287276
for test in tests:
288277
if USE_HOST_DEPS:
@@ -295,7 +284,7 @@ def run_multi_gpu_tests(session):
295284
print("Running multi GPU tests")
296285
session.chdir(os.path.join(TOP_DIR, "tests/py"))
297286
tests = [
298-
"test_multi_gpu.py",
287+
"hw/test_multi_gpu.py",
299288
]
300289
for test in tests:
301290
if USE_HOST_DEPS:
@@ -321,22 +310,18 @@ def run_l0_dla_tests(session):
321310
run_base_tests(session)
322311
cleanup(session)
323312

324-
325-
def run_l1_accuracy_tests(session):
313+
def run_l1_model_tests(session):
326314
if not USE_HOST_DEPS:
327315
install_deps(session)
328316
install_torch_trt(session)
329-
download_datasets(session)
330-
train_model(session)
331-
run_accuracy_tests(session)
317+
download_models(session)
318+
run_model_tests(session)
332319
cleanup(session)
333320

334-
335321
def run_l1_int8_accuracy_tests(session):
336322
if not USE_HOST_DEPS:
337323
install_deps(session)
338324
install_torch_trt(session)
339-
download_datasets(session)
340325
train_model(session)
341326
finetune_model(session)
342327
run_int8_accuracy_tests(session)
@@ -348,7 +333,6 @@ def run_l2_trt_compatibility_tests(session):
348333
install_deps(session)
349334
install_torch_trt(session)
350335
download_models(session)
351-
download_datasets(session)
352336
train_model(session)
353337
run_trt_compatibility_tests(session)
354338
cleanup(session)
@@ -368,18 +352,15 @@ def l0_api_tests(session):
368352
"""When a developer needs to check correctness for a PR or something"""
369353
run_l0_api_tests(session)
370354

371-
372355
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
373356
def l0_dla_tests(session):
374357
"""When a developer needs to check basic api functionality using host dependencies"""
375358
run_l0_dla_tests(session)
376359

377-
378360
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
379-
def l1_accuracy_tests(session):
380-
"""Checking accuracy performance on various usecases"""
381-
run_l1_accuracy_tests(session)
382-
361+
def l1_model_tests(session):
362+
"""When a developer needs to check correctness for a PR or something"""
363+
run_l1_model_tests(session)
383364

384365
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
385366
def l1_int8_accuracy_tests(session):
@@ -397,13 +378,3 @@ def l2_trt_compatibility_tests(session):
397378
def l2_multi_gpu_tests(session):
398379
"""Makes sure that Torch-TensorRT can operate on multi-gpu systems"""
399380
run_l2_multi_gpu_tests(session)
400-
401-
402-
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
403-
def download_test_models(session):
404-
"""Grab all the models needed for testing"""
405-
try:
406-
import torch
407-
except ModuleNotFoundError:
408-
install_deps(session)
409-
download_models(session)

py/torch_tensorrt/ts/_compile_spec.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -225,8 +225,8 @@ def _parse_input_signature(input_signature: Any):
225225

226226

227227
def _parse_compile_spec(compile_spec_: Dict[str, Any]) -> _ts_C.CompileSpec:
228-
# TODO: Remove deep copy once collections does not need partial compilation
229-
compile_spec = deepcopy(compile_spec_)
228+
# TODO: Use deepcopy to support partial compilation of collections
229+
compile_spec = compile_spec_
230230
info = _ts_C.CompileSpec()
231231

232232
if len(compile_spec["inputs"]) > 0:
@@ -301,7 +301,7 @@ def _parse_compile_spec(compile_spec_: Dict[str, Any]) -> _ts_C.CompileSpec:
301301
compile_spec["enabled_precisions"]
302302
)
303303

304-
if "calibrator" in compile_spec:
304+
if "calibrator" in compile_spec and compile_spec["calibrator"]:
305305
info.ptq_calibrator = compile_spec["calibrator"]
306306

307307
if "sparse_weights" in compile_spec:

tests/py/api/test_embed_engines.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,6 @@
44
import torchvision.models as models
55
import copy
66
import timm
7-
import custom_models as cm
87
from typing import Dict
98
from utils import cosine_similarity, COSINE_THRESHOLD
109

tests/py/hw/test_api_dla.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
import torch_tensorrt as torchtrt
33
import torch
44
import torchvision.models as models
5+
from utils import cosine_similarity, COSINE_THRESHOLD
56

67

78
class ModelTestCaseOnDLA(unittest.TestCase):
@@ -39,8 +40,8 @@ def test_compile_traced(self):
3940
}
4041

4142
trt_mod = torchtrt.ts.compile(self.traced_model, **compile_spec)
42-
same = (trt_mod(self.input) - self.traced_model(self.input)).abs().max()
43-
self.assertTrue(same < 2e-2)
43+
cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input))
44+
self.assertTrue(cos_sim > COSINE_THRESHOLD, msg=f"ModelTestCaseOnDLA traced TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}")
4445

4546
def test_compile_script(self):
4647
compile_spec = {
@@ -55,8 +56,8 @@ def test_compile_script(self):
5556
}
5657

5758
trt_mod = torchtrt.ts.compile(self.scripted_model, **compile_spec)
58-
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
59-
self.assertTrue(same < 2e-2)
59+
cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input))
60+
self.assertTrue(cos_sim > COSINE_THRESHOLD, msg=f"ModelTestCaseOnDLA scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}")
6061

6162

6263
def test_suite():

tests/py/hw/test_multi_gpu.py

Lines changed: 9 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -35,9 +35,9 @@ def test_compile_traced(self):
3535

3636
trt_mod = torchtrt.ts.compile(self.traced_model, **compile_spec)
3737
torchtrt.set_device(self.target_gpu)
38-
same = (trt_mod(self.input) - self.traced_model(self.input)).abs().max()
38+
cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input))
3939
torchtrt.set_device(0)
40-
self.assertTrue(same < 2e-3)
40+
self.assertTrue(cos_sim > COSINE_THRESHOLD, msg=f"TestMultiGpuSwitching traced TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}")
4141

4242
def test_compile_script(self):
4343
torchtrt.set_device(0)
@@ -54,9 +54,10 @@ def test_compile_script(self):
5454

5555
trt_mod = torchtrt.ts.compile(self.scripted_model, **compile_spec)
5656
torchtrt.set_device(self.target_gpu)
57-
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
57+
cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input))
5858
torchtrt.set_device(0)
59-
self.assertTrue(same < 2e-3)
59+
self.assertTrue(cos_sim > COSINE_THRESHOLD, msg=f"TestMultiGpuSwitching scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}")
60+
6061

6162

6263
class TestMultiGpuSerializeDeserializeSwitching(ModelTestCase):
@@ -89,8 +90,8 @@ def test_compile_traced(self):
8990
trt_mod = torchtrt.ts.compile(self.traced_model, **compile_spec)
9091
# Changing the device ID deliberately. It should still run on correct device ID by context switching
9192
torchtrt.set_device(1)
92-
same = (trt_mod(self.input) - self.traced_model(self.input)).abs().max()
93-
self.assertTrue(same < 2e-3)
93+
cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input))
94+
self.assertTrue(cos_sim > COSINE_THRESHOLD, msg=f"TestMultiGpuSerializeDeserializeSwitching traced TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}")
9495

9596
def test_compile_script(self):
9697
torchtrt.set_device(0)
@@ -108,8 +109,8 @@ def test_compile_script(self):
108109
trt_mod = torchtrt.ts.compile(self.scripted_model, **compile_spec)
109110
# Changing the device ID deliberately. It should still run on correct device ID by context switching
110111
torchtrt.set_device(1)
111-
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
112-
self.assertTrue(same < 2e-3)
112+
cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input))
113+
self.assertTrue(cos_sim > COSINE_THRESHOLD, msg=f"TestMultiGpuSerializeDeserializeSwitching scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}")
113114

114115

115116
def test_suite():

tests/py/integrations/test_to_backend_api.py

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
import torch_tensorrt as torchtrt
33
import torch
44
import torchvision.models as models
5-
5+
from utils import cosine_similarity, COSINE_THRESHOLD
66

77
class TestToBackendLowering(unittest.TestCase):
88
def setUp(self):
@@ -31,10 +31,9 @@ def setUp(self):
3131

3232
def test_to_backend_lowering(self):
3333
trt_mod = torch._C._jit_to_backend("tensorrt", self.scripted_model, self.spec)
34-
same = (
35-
(trt_mod.forward(self.input) - self.scripted_model(self.input)).abs().max()
36-
)
37-
self.assertTrue(same < 2e-3)
34+
cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input))
35+
self.assertTrue(cos_sim > COSINE_THRESHOLD, msg=f"TestToBackendLowering TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}")
36+
3837

3938

4039
if __name__ == "__main__":

tests/py/integrations/test_trt_intercompatibility.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
import torch
44
import torchvision.models as models
55
import tensorrt as trt
6-
6+
from utils import cosine_similarity, COSINE_THRESHOLD
77

88
class TestPyTorchToTRTEngine(unittest.TestCase):
99
def test_pt_to_trt(self):
@@ -42,8 +42,8 @@ def test_pt_to_trt(self):
4242
device="cuda:0"
4343
).cuda_stream,
4444
)
45-
same = (out - self.ts_model(self.input)).abs().max()
46-
self.assertTrue(same < 2e-3)
45+
cos_sim = cosine_similarity(self.model(self.input), out)
46+
self.assertTrue(cos_sim > COSINE_THRESHOLD, msg=f"TestPyTorchToTRTEngine TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}")
4747

4848

4949
if __name__ == "__main__":
File renamed without changes.
File renamed without changes.

0 commit comments

Comments
 (0)