Skip to content

Commit 236f1a0

Browse files
committed
test: add difference check in test_model_return_type.
1 parent bc81215 commit 236f1a0

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

tests/tests_pytorch/models/test_onnx.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,5 +182,10 @@ def test_model_return_type():
182182
model.example_input_array = torch.randn((1, 32))
183183
model.eval()
184184

185-
ret = model.to_onnx(dynamo=True)
186-
assert isinstance(ret, torch.onnx.ONNXProgram)
185+
onnx_pg = model.to_onnx(dynamo=True)
186+
assert isinstance(onnx_pg, torch.onnx.ONNXProgram)
187+
188+
model_ret = model(model.example_input_array)
189+
inf_ret = onnx_pg(model.example_input_array)
190+
191+
assert torch.allclose(model_ret, inf_ret[0], rtol=1e-03, atol=1e-05)

0 commit comments

Comments
 (0)