import torch
import torchvision.models as models

- from model_test_case import ModelTestCase
+ from multi_gpu_test_case import MultiGpuTestCase

+ gpu_id = 1
class TestCompile(MultiGpuTestCase):

    def setUp(self):
-         if not torch.cuda.device_count() > 1:
-             raise ValueError("This test case is applicable for multi-gpu configurations only")
-
-         self.gpu_id = 1
-         # Setting it up here so that all CUDA allocations are done on correct device
-         trtorch.set_device(self.gpu_id)
        self.input = torch.randn((1, 3, 224, 224)).to("cuda")
        self.traced_model = torch.jit.trace(self.model, [self.input])
        self.scripted_model = torch.jit.script(self.model)
@@ -23,7 +18,7 @@ def test_compile_traced(self):
            "input_shapes": [self.input.shape],
            "device": {
                "device_type": trtorch.DeviceType.GPU,
-                 "gpu_id": self.gpu_id,
+                 "gpu_id": gpu_id,
                "dla_core": 0,
                "allow_gpu_fallback": False,
                "disable_tf32": False
@@ -39,7 +34,7 @@ def test_compile_script(self):
            "input_shapes": [self.input.shape],
            "device": {
                "device_type": trtorch.DeviceType.GPU,
-                 "gpu_id": self.gpu_id,
+                 "gpu_id": gpu_id,
                "dla_core": 0,
                "allow_gpu_fallback": False,
                "disable_tf32": False
@@ -58,6 +53,11 @@ def test_suite():

    return suite

+ if not torch.cuda.device_count() > 1:
+     raise ValueError("This test case is applicable for multi-gpu configurations only")
+
+ # Setting it up here so that all CUDA allocations are done on correct device
+ trtorch.set_device(gpu_id)

suite = test_suite()

runner = unittest.TextTestRunner()
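For reference, the "device" block in the hunks above is one part of the compile spec the test hands to TRTorch. Below is a minimal standalone sketch of the same pattern; the resnet18 model, the variable names, and the trtorch.compile(...) entry point are assumptions based on the TRTorch 0.x Python API, not lines taken from this diff.

# Sketch only; model choice and trtorch.compile call are assumed, not part of this commit.
import torch
import torchvision.models as models
import trtorch

gpu_id = 1
# Select the target GPU before any CUDA allocations so tensors land on the chosen device.
trtorch.set_device(gpu_id)

model = models.resnet18(pretrained=True).eval().to("cuda")
input_data = torch.randn((1, 3, 224, 224)).to("cuda")
traced_model = torch.jit.trace(model, [input_data])

compile_spec = {
    "input_shapes": [input_data.shape],
    "device": {
        "device_type": trtorch.DeviceType.GPU,
        "gpu_id": gpu_id,
        "dla_core": 0,
        "allow_gpu_fallback": False,
        "disable_tf32": False
    }
}

# Assumed compile entry point for the TRTorch 0.x Python API.
trt_model = trtorch.compile(traced_model, compile_spec)
output = trt_model(input_data)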