@@ -232,7 +232,7 @@ def test_dynamic_shape(self):
232232)
233233class TestTorchTensorRTModule (unittest .TestCase ):
234234 @staticmethod
235- def _get_trt_mod ():
235+ def _get_trt_mod (via_ts : bool = False ):
236236 class Test (torch .nn .Module ):
237237 def __init__ (self ):
238238 super (Test , self ).__init__ ()
@@ -244,9 +244,14 @@ def forward(self, x):
244244 return out
245245
246246 mod = torch .jit .script (Test ())
247- test_mod_engine_str = torchtrt .ts .convert_method_to_trt_engine (
248- mod , "forward" , inputs = [torchtrt .Input ((2 , 10 ))]
249- )
247+ if via_ts :
248+ test_mod_engine_str = torchtrt .ts .convert_method_to_trt_engine (
249+ mod , "forward" , inputs = [torchtrt .Input ((2 , 10 ))]
250+ )
251+ else :
252+ test_mod_engine_str = torchtrt .convert_method_to_trt_engine (
253+ mod , "forward" , inputs = [torchtrt .Input ((2 , 10 ))]
254+ )
250255 return TorchTensorRTModule (
251256 name = "test" ,
252257 serialized_engine = test_mod_engine_str ,
@@ -301,9 +306,12 @@ def forward(self, x):
301306 )
302307
303308 def test_set_get_profile_path_prefix (self ):
304- trt_mod = TestTorchTensorRTModule ._get_trt_mod ()
305- trt_mod .engine .profile_path_prefix = "/tmp/"
306- self .assertTrue (trt_mod .engine .profile_path_prefix == "/tmp/" )
309+ for trt_mod in (
310+ TestTorchTensorRTModule ._get_trt_mod (),
311+ TestTorchTensorRTModule ._get_trt_mod (via_ts = True ),
312+ ):
313+ trt_mod .engine .profile_path_prefix = "/tmp/"
314+ self .assertTrue (trt_mod .engine .profile_path_prefix == "/tmp/" )
307315
308316 def test_get_layer_info (self ):
309317 """
@@ -321,11 +329,14 @@ def test_get_layer_info(self):
321329
322330 import json
323331
324- trt_mod = TestTorchTensorRTModule ._get_trt_mod ()
325- trt_json = json .loads (trt_mod .get_layer_info ())
326- [self .assertTrue (k in trt_json .keys ()) for k in ["Layers" , "Bindings" ]]
327- self .assertTrue (len (trt_json ["Layers" ]) == 2 )
328- self .assertTrue (len (trt_json ["Bindings" ]) == 2 )
332+ for trt_mod in (
333+ TestTorchTensorRTModule ._get_trt_mod (),
334+ TestTorchTensorRTModule ._get_trt_mod (via_ts = True ),
335+ ):
336+ trt_json = json .loads (trt_mod .get_layer_info ())
337+ [self .assertTrue (k in trt_json .keys ()) for k in ["Layers" , "Bindings" ]]
338+ self .assertTrue (len (trt_json ["Layers" ]) == 2 )
339+ self .assertTrue (len (trt_json ["Bindings" ]) == 2 )
329340
330341
331342if __name__ == "__main__" :
0 commit comments