@@ -99,14 +99,14 @@ def test_tensorrt_save_ir_type(ir, export_type):
9999) 
100100@RunIf (tensorrt = True , min_cuda_gpus = 1 , min_torch = "2.2.0" ) 
101101def  test_tensorrt_export_reload (output_format , ir , tmp_path ):
102-     import  torch_tensorrt 
103- 
104102    if  ir  ==  "ts"  and  output_format  ==  "exported_program" :
105103        pytest .skip ("TorchScript cannot be exported as exported_program" )
106104
105+     import  torch_tensorrt 
106+ 
107107    model  =  BoringModel ()
108108    model .cuda ().eval ()
109-     model .example_input_array  =  torch .randn ((4 , 32 ))
109+     model .example_input_array  =  torch .ones ((4 , 32 ))
110110
111111    file_path  =  os .path .join (tmp_path , "model.trt" )
112112    model .to_tensorrt (file_path , output_format = output_format , ir = ir )
@@ -116,7 +116,7 @@ def test_tensorrt_export_reload(output_format, ir, tmp_path):
116116        loaded_model  =  loaded_model .module ()
117117
118118    with  torch .no_grad (), torch .inference_mode ():
119-         model_output  =  model (model .example_input_array .to (model .device ))
119+         model_output  =  model (model .example_input_array .to ("cuda" ))
120+         jit_output  =  loaded_model (model .example_input_array .to ("cuda" ))
120121
121-     jit_output  =  loaded_model (model .example_input_array .to ("cuda" ))
122-     assert  torch .allclose (model_output , jit_output )
122+     assert  torch .allclose (model_output , jit_output , rtol = 1e-03 , atol = 1e-06 )
0 commit comments