Skip to content

Commit 2840531

Browse files
fix issue 3259 (#3260)
1 parent 6d40ff1 commit 2840531

File tree

2 files changed

+32
-21
lines changed

2 files changed

+32
-21
lines changed

py/torch_tensorrt/_compile.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,15 @@ def convert_method_to_trt_engine(
343343
enabled_precisions if enabled_precisions is not None else {torch.float}
344344
)
345345

346+
if not arg_inputs and not inputs:
347+
raise AssertionError("'arg_inputs' and 'inputs' should not both be None.")
348+
349+
elif arg_inputs and inputs:
350+
raise AssertionError(
351+
"'arg_inputs' and 'inputs' should not be used at the same time."
352+
)
353+
arg_inputs = arg_inputs or inputs
354+
346355
module_type = _parse_module_type(module)
347356
target_ir = _get_target_fe(module_type, ir)
348357
if target_ir == _IRType.ts:
@@ -366,15 +375,6 @@ def convert_method_to_trt_engine(
366375
)
367376
elif target_ir == _IRType.dynamo:
368377
# Prepare torch and torchtrt inputs
369-
if not arg_inputs and not inputs:
370-
raise AssertionError("'arg_inputs' and 'inputs' should not both be None.")
371-
372-
elif arg_inputs and inputs:
373-
raise AssertionError(
374-
"'arg_inputs' and 'inputs' should not be used at the same time."
375-
)
376-
arg_inputs = arg_inputs or inputs
377-
378378
if kwarg_inputs is None:
379379
kwarg_inputs = {}
380380

tests/py/ts/api/test_classes.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ def test_dynamic_shape(self):
232232
)
233233
class TestTorchTensorRTModule(unittest.TestCase):
234234
@staticmethod
235-
def _get_trt_mod():
235+
def _get_trt_mod(via_ts: bool = False):
236236
class Test(torch.nn.Module):
237237
def __init__(self):
238238
super(Test, self).__init__()
@@ -244,9 +244,14 @@ def forward(self, x):
244244
return out
245245

246246
mod = torch.jit.script(Test())
247-
test_mod_engine_str = torchtrt.ts.convert_method_to_trt_engine(
248-
mod, "forward", inputs=[torchtrt.Input((2, 10))]
249-
)
247+
if via_ts:
248+
test_mod_engine_str = torchtrt.ts.convert_method_to_trt_engine(
249+
mod, "forward", inputs=[torchtrt.Input((2, 10))]
250+
)
251+
else:
252+
test_mod_engine_str = torchtrt.convert_method_to_trt_engine(
253+
mod, "forward", inputs=[torchtrt.Input((2, 10))]
254+
)
250255
return TorchTensorRTModule(
251256
name="test",
252257
serialized_engine=test_mod_engine_str,
@@ -301,9 +306,12 @@ def forward(self, x):
301306
)
302307

303308
def test_set_get_profile_path_prefix(self):
304-
trt_mod = TestTorchTensorRTModule._get_trt_mod()
305-
trt_mod.engine.profile_path_prefix = "/tmp/"
306-
self.assertTrue(trt_mod.engine.profile_path_prefix == "/tmp/")
309+
for trt_mod in (
310+
TestTorchTensorRTModule._get_trt_mod(),
311+
TestTorchTensorRTModule._get_trt_mod(via_ts=True),
312+
):
313+
trt_mod.engine.profile_path_prefix = "/tmp/"
314+
self.assertTrue(trt_mod.engine.profile_path_prefix == "/tmp/")
307315

308316
def test_get_layer_info(self):
309317
"""
@@ -321,11 +329,14 @@ def test_get_layer_info(self):
321329

322330
import json
323331

324-
trt_mod = TestTorchTensorRTModule._get_trt_mod()
325-
trt_json = json.loads(trt_mod.get_layer_info())
326-
[self.assertTrue(k in trt_json.keys()) for k in ["Layers", "Bindings"]]
327-
self.assertTrue(len(trt_json["Layers"]) == 2)
328-
self.assertTrue(len(trt_json["Bindings"]) == 2)
332+
for trt_mod in (
333+
TestTorchTensorRTModule._get_trt_mod(),
334+
TestTorchTensorRTModule._get_trt_mod(via_ts=True),
335+
):
336+
trt_json = json.loads(trt_mod.get_layer_info())
337+
[self.assertTrue(k in trt_json.keys()) for k in ["Layers", "Bindings"]]
338+
self.assertTrue(len(trt_json["Layers"]) == 2)
339+
self.assertTrue(len(trt_json["Bindings"]) == 2)
329340

330341

331342
if __name__ == "__main__":

0 commit comments

Comments
 (0)