Skip to content

Commit 53257f9

Browse files
committed
fix: fix unittest.
1 parent 958968d commit 53257f9

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

tests/tests_pytorch/models/test_torch_tensorrt.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -99,14 +99,14 @@ def test_tensorrt_save_ir_type(ir, export_type):
9999
)
100100
@RunIf(tensorrt=True, min_cuda_gpus=1, min_torch="2.2.0")
101101
def test_tensorrt_export_reload(output_format, ir, tmp_path):
102-
import torch_tensorrt
103-
104102
if ir == "ts" and output_format == "exported_program":
105103
pytest.skip("TorchScript cannot be exported as exported_program")
106104

105+
import torch_tensorrt
106+
107107
model = BoringModel()
108108
model.cuda().eval()
109-
model.example_input_array = torch.randn((4, 32))
109+
model.example_input_array = torch.ones((4, 32))
110110

111111
file_path = os.path.join(tmp_path, "model.trt")
112112
model.to_tensorrt(file_path, output_format=output_format, ir=ir)
@@ -116,7 +116,7 @@ def test_tensorrt_export_reload(output_format, ir, tmp_path):
116116
loaded_model = loaded_model.module()
117117

118118
with torch.no_grad(), torch.inference_mode():
119-
model_output = model(model.example_input_array.to(model.device))
119+
model_output = model(model.example_input_array.to("cuda"))
120+
jit_output = loaded_model(model.example_input_array.to("cuda"))
120121

121-
jit_output = loaded_model(model.example_input_array.to("cuda"))
122-
assert torch.allclose(model_output, jit_output)
122+
assert torch.allclose(model_output, jit_output, rtol=1e-03, atol=1e-06)

0 commit comments

Comments
 (0)