if USE_HOST_DEPS:
    print("Using dependencies from host python")

+# Set epochs to train VGG model for accuracy tests
+EPOCHS = 25
+
SUPPORTED_PYTHON_VERSIONS = ["3.7", "3.8", "3.9", "3.10"]

nox.options.sessions = [
    "l0_api_tests-" + "{}.{}".format(sys.version_info.major, sys.version_info.minor)
]

-
def install_deps(session):
    print("Installing deps")
    session.install("-r", os.path.join(TOP_DIR, "py", "requirements.txt"))
@@ -63,31 +65,6 @@ def install_torch_trt(session):
    session.run("python", "setup.py", "develop")


-def download_datasets(session):
-    print(
-        "Downloading dataset to path",
-        os.path.join(TOP_DIR, "examples/int8/training/vgg16"),
-    )
-    session.chdir(os.path.join(TOP_DIR, "examples/int8/training/vgg16"))
-    session.run_always(
-        "wget", "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz", external=True
-    )
-    session.run_always("tar", "-xvzf", "cifar-10-binary.tar.gz", external=True)
-    session.run_always(
-        "mkdir",
-        "-p",
-        os.path.join(TOP_DIR, "tests/accuracy/datasets/data"),
-        external=True,
-    )
-    session.run_always(
-        "cp",
-        "-rpf",
-        os.path.join(TOP_DIR, "examples/int8/training/vgg16/cifar-10-batches-bin"),
-        os.path.join(TOP_DIR, "tests/accuracy/datasets/data/cidar-10-batches-bin"),
-        external=True,
-    )
-
-
def train_model(session):
    session.chdir(os.path.join(TOP_DIR, "examples/int8/training/vgg16"))
    session.install("-r", "requirements.txt")
@@ -107,14 +84,14 @@ def train_model(session):
            "--ckpt-dir",
            "vgg16_ckpts",
            "--epochs",
-            "25",
+            str(EPOCHS),
            env={"PYTHONPATH": PYT_PATH},
        )

        session.run_always(
            "python",
            "export_ckpt.py",
-            "vgg16_ckpts/ckpt_epoch25.pth",
+            "vgg16_ckpts/ckpt_epoch" + str(EPOCHS) + ".pth",
            env={"PYTHONPATH": PYT_PATH},
        )
    else:
@@ -130,10 +107,10 @@ def train_model(session):
            "--ckpt-dir",
            "vgg16_ckpts",
            "--epochs",
-            "25",
+            str(EPOCHS),
        )

-        session.run_always("python", "export_ckpt.py", "vgg16_ckpts/ckpt_epoch25.pth")
+        session.run_always("python", "export_ckpt.py", "vgg16_ckpts/ckpt_epoch" + str(EPOCHS) + ".pth")


def finetune_model(session):
@@ -156,17 +133,17 @@ def finetune_model(session):
            "--ckpt-dir",
            "vgg16_ckpts",
            "--start-from",
-            "25",
+            str(EPOCHS),
            "--epochs",
-            "26",
+            str(EPOCHS + 1),
            env={"PYTHONPATH": PYT_PATH},
        )

        # Export model
        session.run_always(
            "python",
            "export_qat.py",
-            "vgg16_ckpts/ckpt_epoch26.pth",
+            "vgg16_ckpts/ckpt_epoch" + str(EPOCHS + 1) + ".pth",
            env={"PYTHONPATH": PYT_PATH},
        )
    else:
@@ -182,13 +159,13 @@ def finetune_model(session):
            "--ckpt-dir",
            "vgg16_ckpts",
            "--start-from",
-            "25",
+            str(EPOCHS),
            "--epochs",
-            "26",
+            str(EPOCHS + 1),
        )

        # Export model
-        session.run_always("python", "export_qat.py", "vgg16_ckpts/ckpt_epoch26.pth")
+        session.run_always("python", "export_qat.py", "vgg16_ckpts/ckpt_epoch" + str(EPOCHS + 1) + ".pth")


def cleanup(session):
@@ -209,7 +186,7 @@ def run_base_tests(session):
    print("Running basic tests")
    session.chdir(os.path.join(TOP_DIR, "tests/py"))
    tests = [
-        "api",
+        "api/test_e2e_behavior.py",
        "integrations/test_to_backend_api.py",
    ]
    for test in tests:
@@ -218,6 +195,18 @@ def run_base_tests(session):
        else:
            session.run_always("pytest", test)

+def run_model_tests(session):
+    print("Running model tests")
+    session.chdir(os.path.join(TOP_DIR, "tests/py"))
+    tests = [
+        "models",
+    ]
+    for test in tests:
+        if USE_HOST_DEPS:
+            session.run_always("pytest", test, env={"PYTHONPATH": PYT_PATH})
+        else:
+            session.run_always("pytest", test)
+

def run_accuracy_tests(session):
    print("Running accuracy tests")
@@ -268,8 +257,8 @@ def run_trt_compatibility_tests(session):
    copy_model(session)
    session.chdir(os.path.join(TOP_DIR, "tests/py"))
    tests = [
-        "test_trt_intercompatibility.py",
-        "test_ptq_trt_calibrator.py",
+        "integrations/test_trt_intercompatibility.py",
+        # "ptq/test_ptq_trt_calibrator.py",
    ]
    for test in tests:
        if USE_HOST_DEPS:
@@ -282,7 +271,7 @@ def run_dla_tests(session):
    print("Running DLA tests")
    session.chdir(os.path.join(TOP_DIR, "tests/py"))
    tests = [
-        "test_api_dla.py",
+        "hw/test_api_dla.py",
    ]
    for test in tests:
        if USE_HOST_DEPS:
@@ -295,7 +284,7 @@ def run_multi_gpu_tests(session):
    print("Running multi GPU tests")
    session.chdir(os.path.join(TOP_DIR, "tests/py"))
    tests = [
-        "test_multi_gpu.py",
+        "hw/test_multi_gpu.py",
    ]
    for test in tests:
        if USE_HOST_DEPS:
@@ -321,22 +310,18 @@ def run_l0_dla_tests(session):
    run_base_tests(session)
    cleanup(session)

-
-def run_l1_accuracy_tests(session):
+def run_l1_model_tests(session):
    if not USE_HOST_DEPS:
        install_deps(session)
        install_torch_trt(session)
-        download_datasets(session)
-        train_model(session)
-    run_accuracy_tests(session)
+        download_models(session)
+    run_model_tests(session)
    cleanup(session)

-
def run_l1_int8_accuracy_tests(session):
    if not USE_HOST_DEPS:
        install_deps(session)
        install_torch_trt(session)
-        download_datasets(session)
        train_model(session)
        finetune_model(session)
    run_int8_accuracy_tests(session)
@@ -348,7 +333,6 @@ def run_l2_trt_compatibility_tests(session):
        install_deps(session)
        install_torch_trt(session)
        download_models(session)
-        download_datasets(session)
        train_model(session)
    run_trt_compatibility_tests(session)
    cleanup(session)
@@ -368,18 +352,15 @@ def l0_api_tests(session):
    """When a developer needs to check correctness for a PR or something"""
    run_l0_api_tests(session)

-
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
def l0_dla_tests(session):
    """When a developer needs to check basic api functionality using host dependencies"""
    run_l0_dla_tests(session)

-
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
-def l1_accuracy_tests(session):
-    """Checking accuracy performance on various usecases"""
-    run_l1_accuracy_tests(session)
-
+def l1_model_tests(session):
+    """When a developer needs to check correctness for a PR or something"""
+    run_l1_model_tests(session)

@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
def l1_int8_accuracy_tests(session):
@@ -397,13 +378,3 @@ def l2_trt_compatibility_tests(session):
def l2_multi_gpu_tests(session):
    """Makes sure that Torch-TensorRT can operate on multi-gpu systems"""
    run_l2_multi_gpu_tests(session)
-
-
-@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
-def download_test_models(session):
-    """Grab all the models needed for testing"""
-    try:
-        import torch
-    except ModuleNotFoundError:
-        install_deps(session)
-    download_models(session)
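For reference, a minimal sketch of how the EPOCHS constant introduced in this change composes the checkpoint paths (illustrative only; with EPOCHS = 25 it reproduces exactly the hard-coded names the diff removes):

    # Assumed values, mirroring the noxfile above
    EPOCHS = 25
    ckpt = "vgg16_ckpts/ckpt_epoch" + str(EPOCHS) + ".pth"          # -> "vgg16_ckpts/ckpt_epoch25.pth"
    qat_ckpt = "vgg16_ckpts/ckpt_epoch" + str(EPOCHS + 1) + ".pth"  # -> "vgg16_ckpts/ckpt_epoch26.pth"

After this change, the renamed session would typically be selected with an invocation such as nox -s l1_model_tests-3.8 (assuming nox is installed and a Python version listed in SUPPORTED_PYTHON_VERSIONS); that invocation is an assumption about local usage, not part of the diff itself.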