@@ -202,6 +202,7 @@ def run_base_tests(session):
202
202
else :
203
203
session .run_always ("pytest" , test )
204
204
205
+
205
206
def run_fx_core_tests (session ):
206
207
print ("Running FX core tests" )
207
208
session .chdir (os .path .join (TOP_DIR , "py/torch_tensorrt/fx/test" ))
@@ -214,6 +215,7 @@ def run_fx_core_tests(session):
214
215
else :
215
216
session .run_always ("pytest" , test )
216
217
218
+
217
219
def run_fx_converter_tests (session ):
218
220
print ("Running FX converter tests" )
219
221
session .chdir (os .path .join (TOP_DIR , "py/torch_tensorrt/fx/test" ))
@@ -229,6 +231,7 @@ def run_fx_converter_tests(session):
229
231
else :
230
232
session .run_always ("pytest" , test , skip_tests )
231
233
234
+
232
235
def run_fx_lower_tests (session ):
233
236
print ("Running FX passes and trt_lower tests" )
234
237
session .chdir (os .path .join (TOP_DIR , "py/torch_tensorrt/fx/test" ))
@@ -237,7 +240,7 @@ def run_fx_lower_tests(session):
237
240
# "passes/test_fuse_permute_linear_trt.py",
238
241
"passes/test_remove_duplicate_output_args.py" ,
239
242
"passes/test_fuse_permute_matmul_trt.py" ,
240
- #"passes/test_graph_opts.py"
243
+ # "passes/test_graph_opts.py"
241
244
"trt_lower" ,
242
245
]
243
246
for test in tests :
@@ -246,6 +249,7 @@ def run_fx_lower_tests(session):
246
249
else :
247
250
session .run_always ("pytest" , test )
248
251
252
+
249
253
def run_fx_quant_tests (session ):
250
254
print ("Running FX Quant tests" )
251
255
session .chdir (os .path .join (TOP_DIR , "py/torch_tensorrt/fx/test" ))
@@ -261,6 +265,7 @@ def run_fx_quant_tests(session):
261
265
else :
262
266
session .run_always ("pytest" , test , skip_tests )
263
267
268
+
264
269
def run_fx_tracer_tests (session ):
265
270
print ("Running FX Tracer tests" )
266
271
session .chdir (os .path .join (TOP_DIR , "py/torch_tensorrt/fx/test" ))
@@ -269,14 +274,15 @@ def run_fx_tracer_tests(session):
269
274
tests = [
270
275
"tracer/test_acc_shape_prop.py" ,
271
276
"tracer/test_acc_tracer.py" ,
272
- #"tracer/test_dispatch_tracer.py"
277
+ # "tracer/test_dispatch_tracer.py"
273
278
]
274
279
for test in tests :
275
280
if USE_HOST_DEPS :
276
281
session .run_always ("pytest" , test , env = {"PYTHONPATH" : PYT_PATH })
277
282
else :
278
283
session .run_always ("pytest" , test )
279
284
285
+
280
286
def run_fx_tools_tests (session ):
281
287
print ("Running FX tools tests" )
282
288
session .chdir (os .path .join (TOP_DIR , "py/torch_tensorrt/fx/test" ))
@@ -396,6 +402,7 @@ def run_l0_api_tests(session):
396
402
run_base_tests (session )
397
403
cleanup (session )
398
404
405
+
399
406
def run_l0_fx_tests (session ):
400
407
if not USE_HOST_DEPS :
401
408
install_deps (session )
@@ -405,27 +412,31 @@ def run_l0_fx_tests(session):
405
412
run_fx_lower_tests (session )
406
413
cleanup (session )
407
414
415
+
408
416
def run_l0_fx_core_tests (session ):
409
417
if not USE_HOST_DEPS :
410
418
install_deps (session )
411
419
install_torch_trt (session )
412
420
run_fx_core_tests (session )
413
421
cleanup (session )
414
422
423
+
415
424
def run_l0_fx_converter_tests (session ):
416
425
if not USE_HOST_DEPS :
417
426
install_deps (session )
418
427
install_torch_trt (session )
419
428
run_fx_converter_tests (session )
420
429
cleanup (session )
421
430
431
+
422
432
def run_l0_fx_lower_tests (session ):
423
433
if not USE_HOST_DEPS :
424
434
install_deps (session )
425
435
install_torch_trt (session )
426
436
run_fx_lower_tests (session )
427
437
cleanup (session )
428
438
439
+
429
440
def run_l0_dla_tests (session ):
430
441
if not USE_HOST_DEPS :
431
442
install_deps (session )
@@ -443,6 +454,7 @@ def run_l1_model_tests(session):
443
454
run_model_tests (session )
444
455
cleanup (session )
445
456
457
+
446
458
def run_l1_int8_accuracy_tests (session ):
447
459
if not USE_HOST_DEPS :
448
460
install_deps (session )
@@ -452,6 +464,7 @@ def run_l1_int8_accuracy_tests(session):
452
464
run_int8_accuracy_tests (session )
453
465
cleanup (session )
454
466
467
+
455
468
def run_l1_fx_tests (session ):
456
469
if not USE_HOST_DEPS :
457
470
install_deps (session )
@@ -461,6 +474,7 @@ def run_l1_fx_tests(session):
461
474
run_fx_tools_tests (session )
462
475
cleanup (session )
463
476
477
+
464
478
def run_l2_trt_compatibility_tests (session ):
465
479
if not USE_HOST_DEPS :
466
480
install_deps (session )
@@ -483,26 +497,31 @@ def l0_api_tests(session):
483
497
"""When a developer needs to check correctness for a PR or something"""
484
498
run_l0_api_tests (session )
485
499
500
+
486
501
@nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
487
502
def l0_fx_tests (session ):
488
503
"""When a developer needs to check correctness for a PR or something"""
489
504
run_l0_fx_tests (session )
490
505
506
+
491
507
@nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
492
508
def l0_fx_core_tests (session ):
493
509
"""When a developer needs to check correctness for a PR or something"""
494
510
run_l0_fx_core_tests (session )
495
511
512
+
496
513
@nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
497
514
def l0_fx_converter_tests (session ):
498
515
"""When a developer needs to check correctness for a PR or something"""
499
516
run_l0_fx_converter_tests (session )
500
517
518
+
501
519
@nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
502
520
def l0_fx_lower_tests (session ):
503
521
"""When a developer needs to check correctness for a PR or something"""
504
522
run_l0_fx_lower_tests (session )
505
523
524
+
506
525
@nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
507
526
def l0_dla_tests (session ):
508
527
"""When a developer needs to check basic api functionality using host dependencies"""
@@ -514,11 +533,13 @@ def l1_model_tests(session):
514
533
"""When a user needs to test the functionality of standard models compilation and results"""
515
534
run_l1_model_tests (session )
516
535
536
+
517
537
@nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
518
538
def l1_fx_tests (session ):
519
539
"""When a user needs to test the functionality of standard models compilation and results"""
520
540
run_l1_fx_tests (session )
521
541
542
+
522
543
@nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
523
544
def l1_int8_accuracy_tests (session ):
524
545
"""Checking accuracy performance on various usecases"""
@@ -534,4 +555,4 @@ def l2_trt_compatibility_tests(session):
534
555
@nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
535
556
def l2_multi_gpu_tests (session ):
536
557
"""Makes sure that Torch-TensorRT can operate on multi-gpu systems"""
537
- run_l2_multi_gpu_tests (session )
558
+ run_l2_multi_gpu_tests (session )
0 commit comments