Skip to content

Commit f49604d

Browse files
authored
Merge pull request #677 from NVIDIA/model_tracing
fix: Fix python API tests for mobilenet v2
2 parents 158cffa + e5a38ff commit f49604d

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

tests/py/test_api.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,15 @@ def test_from_torch_tensor(self):
5151
"enabled_precisions": {torch.float}
5252
}
5353

54-
trt_mod = trtorch.compile(self.scripted_model, compile_spec)
55-
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
54+
trt_mod = trtorch.compile(self.traced_model, compile_spec)
55+
same = (trt_mod(self.input) - self.traced_model(self.input)).abs().max()
5656
self.assertTrue(same < 2e-2)
5757

5858
def test_device(self):
5959
compile_spec = {"inputs": [self.input], "device": trtorch.Device("gpu:0"), "enabled_precisions": {torch.float}}
6060

61-
trt_mod = trtorch.compile(self.scripted_model, compile_spec)
62-
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
61+
trt_mod = trtorch.compile(self.traced_model, compile_spec)
62+
same = (trt_mod(self.input) - self.traced_model(self.input)).abs().max()
6363
self.assertTrue(same < 2e-2)
6464

6565

@@ -169,7 +169,7 @@ class TestPTtoTRTtoPT(ModelTestCase):
169169

170170
def setUp(self):
171171
self.input = torch.randn((1, 3, 224, 224)).to("cuda")
172-
self.ts_model = torch.jit.script(self.model)
172+
self.ts_model = torch.jit.trace(self.model, [self.input])
173173

174174
def test_pt_to_trt_to_pt(self):
175175
compile_spec = {

0 commit comments

Comments
 (0)