55import pytest
66import torch
77
8- import tests_pytorch .helpers .pipelines as tpipes
8+ import tests_pytorch .helpers .pipelines as pipes
99from lightning .pytorch .demos .boring_classes import BoringModel
1010from tests_pytorch .helpers .runif import RunIf
1111
1212
13- @RunIf (tensorrt = True , min_cuda_gpus = 1 )
13+ @RunIf (tensorrt = True , min_cuda_gpus = 1 , min_torch = "2.2.0" )
1414def test_tensorrt_saves_with_input_sample (tmp_path ):
1515 model = BoringModel ()
1616 ori_device = model .device
@@ -34,6 +34,7 @@ def test_tensorrt_saves_with_input_sample(tmp_path):
3434 assert len (file_path .getvalue ()) > 4e2
3535
3636
37+ @RunIf (tensorrt = True , min_cuda_gpus = 1 , min_torch = "2.2.0" )
3738def test_tensorrt_error_if_no_input (tmp_path ):
3839 model = BoringModel ()
3940 model .example_input_array = None
@@ -47,7 +48,7 @@ def test_tensorrt_error_if_no_input(tmp_path):
4748 model .to_tensorrt (file_path )
4849
4950
50- @RunIf (tensorrt = True , min_cuda_gpus = 2 )
51+ @RunIf (tensorrt = True , min_cuda_gpus = 2 , min_torch = "2.2.0" )
5152def test_tensorrt_saves_on_multi_gpu (tmp_path ):
5253 trainer_options = {
5354 "default_root_dir" : tmp_path ,
@@ -63,7 +64,7 @@ def test_tensorrt_saves_on_multi_gpu(tmp_path):
6364 model = BoringModel ()
6465 model .example_input_array = torch .randn ((4 , 32 ))
6566
66- tpipes .run_model_test (trainer_options , model , min_acc = 0.08 )
67+ pipes .run_model_test (trainer_options , model , min_acc = 0.08 )
6768
6869 file_path = os .path .join (tmp_path , "model.trt" )
6970 model .to_tensorrt (file_path )
@@ -79,7 +80,7 @@ def test_tensorrt_saves_on_multi_gpu(tmp_path):
7980 ("ts" , torch .jit .ScriptModule ),
8081 ],
8182)
82- @RunIf (tensorrt = True , min_cuda_gpus = 1 )
83+ @RunIf (tensorrt = True , min_cuda_gpus = 1 , min_torch = "2.2.0" )
8384def test_tensorrt_save_ir_type (ir , export_type ):
8485 model = BoringModel ()
8586 model .example_input_array = torch .randn ((4 , 32 ))
@@ -96,7 +97,7 @@ def test_tensorrt_save_ir_type(ir, export_type):
9697 "ir" ,
9798 ["default" , "dynamo" , "ts" ],
9899)
99- @RunIf (tensorrt = True , min_cuda_gpus = 1 )
100+ @RunIf (tensorrt = True , min_cuda_gpus = 1 , min_torch = "2.2.0" )
100101def test_tensorrt_export_reload (output_format , ir , tmp_path ):
101102 import torch_tensorrt
102103
0 commit comments