File tree Expand file tree Collapse file tree 1 file changed +5
-5
lines changed Expand file tree Collapse file tree 1 file changed +5
-5
lines changed Original file line number Diff line number Diff line change @@ -51,15 +51,15 @@ def test_from_torch_tensor(self):
51
51
"enabled_precisions" : {torch .float }
52
52
}
53
53
54
- trt_mod = trtorch .compile (self .scripted_model , compile_spec )
55
- same = (trt_mod (self .input ) - self .scripted_model (self .input )).abs ().max ()
54
+ trt_mod = trtorch .compile (self .traced_model , compile_spec )
55
+ same = (trt_mod (self .input ) - self .traced_model (self .input )).abs ().max ()
56
56
self .assertTrue (same < 2e-2 )
57
57
58
58
def test_device (self ):
59
59
compile_spec = {"inputs" : [self .input ], "device" : trtorch .Device ("gpu:0" ), "enabled_precisions" : {torch .float }}
60
60
61
- trt_mod = trtorch .compile (self .scripted_model , compile_spec )
62
- same = (trt_mod (self .input ) - self .scripted_model (self .input )).abs ().max ()
61
+ trt_mod = trtorch .compile (self .traced_model , compile_spec )
62
+ same = (trt_mod (self .input ) - self .traced_model (self .input )).abs ().max ()
63
63
self .assertTrue (same < 2e-2 )
64
64
65
65
@@ -169,7 +169,7 @@ class TestPTtoTRTtoPT(ModelTestCase):
169
169
170
170
def setUp (self ):
171
171
self .input = torch .randn ((1 , 3 , 224 , 224 )).to ("cuda" )
172
- self .ts_model = torch .jit .script (self .model )
172
+ self .ts_model = torch .jit .trace (self .model , [ self . input ] )
173
173
174
174
def test_pt_to_trt_to_pt (self ):
175
175
compile_spec = {
You can’t perform that action at this time.
0 commit comments