@@ -99,14 +99,14 @@ def test_tensorrt_save_ir_type(ir, export_type):
99
99
)
100
100
@RunIf (tensorrt = True , min_cuda_gpus = 1 , min_torch = "2.2.0" )
101
101
def test_tensorrt_export_reload (output_format , ir , tmp_path ):
102
- import torch_tensorrt
103
-
104
102
if ir == "ts" and output_format == "exported_program" :
105
103
pytest .skip ("TorchScript cannot be exported as exported_program" )
106
104
105
+ import torch_tensorrt
106
+
107
107
model = BoringModel ()
108
108
model .cuda ().eval ()
109
- model .example_input_array = torch .randn ((4 , 32 ))
109
+ model .example_input_array = torch .ones ((4 , 32 ))
110
110
111
111
file_path = os .path .join (tmp_path , "model.trt" )
112
112
model .to_tensorrt (file_path , output_format = output_format , ir = ir )
@@ -116,7 +116,7 @@ def test_tensorrt_export_reload(output_format, ir, tmp_path):
116
116
loaded_model = loaded_model .module ()
117
117
118
118
with torch .no_grad (), torch .inference_mode ():
119
- model_output = model (model .example_input_array .to (model .device ))
119
+ model_output = model (model .example_input_array .to ("cuda" ))
120
+ jit_output = loaded_model (model .example_input_array .to ("cuda" ))
120
121
121
- jit_output = loaded_model (model .example_input_array .to ("cuda" ))
122
- assert torch .allclose (model_output , jit_output )
122
+ assert torch .allclose (model_output , jit_output , rtol = 1e-03 , atol = 1e-06 )
0 commit comments