@@ -29,7 +29,10 @@ def test_compile_traced(self):
29
29
self .assertTrue (same < 2e-2 )
30
30
31
31
def test_compile_script (self ):
32
- trt_mod = trtorch .compile (self .scripted_model , inputs = [self .input ], device = trtorch .Device (gpu_id = 0 ), enabled_precisions = {torch .float })
32
+ trt_mod = trtorch .compile (self .scripted_model ,
33
+ inputs = [self .input ],
34
+ device = trtorch .Device (gpu_id = 0 ),
35
+ enabled_precisions = {torch .float })
33
36
same = (trt_mod (self .input ) - self .scripted_model (self .input )).abs ().max ()
34
37
self .assertTrue (same < 2e-2 )
35
38
@@ -245,8 +248,7 @@ def test_input_use_default_fp16_without_fp16_enabled(self):
245
248
half_mod = torch .jit .script (self .model )
246
249
half_mod .half ()
247
250
248
- trt_mod = trtorch .compile (half_mod ,
249
- inputs = [trtorch .Input (self .input .shape )])
251
+ trt_mod = trtorch .compile (half_mod , inputs = [trtorch .Input (self .input .shape )])
250
252
trt_mod (self .input .half ())
251
253
252
254
def test_input_respect_user_setting_fp16_weights_fp32_in (self ):
@@ -358,8 +360,12 @@ def test_suite():
358
360
suite .addTest (TestCompileHalf .parametrize (TestCompileHalf , model = models .resnet18 (pretrained = True )))
359
361
suite .addTest (TestCompileHalfDefault .parametrize (TestCompileHalfDefault , model = models .resnet18 (pretrained = True )))
360
362
suite .addTest (TestPTtoTRTtoPT .parametrize (TestPTtoTRTtoPT , model = models .mobilenet_v2 (pretrained = True )))
361
- suite .addTest (TestInputTypeDefaultsFP32Model .parametrize (TestInputTypeDefaultsFP32Model , model = models .resnet18 (pretrained = True )))
362
- suite .addTest (TestInputTypeDefaultsFP16Model .parametrize (TestInputTypeDefaultsFP16Model , model = models .resnet18 (pretrained = True )))
363
+ suite .addTest (
364
+ TestInputTypeDefaultsFP32Model .parametrize (TestInputTypeDefaultsFP32Model ,
365
+ model = models .resnet18 (pretrained = True )))
366
+ suite .addTest (
367
+ TestInputTypeDefaultsFP16Model .parametrize (TestInputTypeDefaultsFP16Model ,
368
+ model = models .resnet18 (pretrained = True )))
363
369
suite .addTest (TestFallbackToTorch .parametrize (TestFallbackToTorch , model = models .resnet18 (pretrained = True )))
364
370
suite .addTest (
365
371
TestModuleFallbackToTorch .parametrize (TestModuleFallbackToTorch , model = models .resnet18 (pretrained = True )))
0 commit comments