Skip to content

Commit 5335fab

Browse files
author
liangtao07
committed
Resolve conversation.
Move _parse_module_type out _module_ir. Change _module_ir to _get_target_ir. Signed-off-by: liangtao07 <[email protected]>
1 parent 9f971ed commit 5335fab

File tree

1 file changed

+9
-10
lines changed

1 file changed

+9
-10
lines changed

py/torch_tensorrt/_compile.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Dict, Tuple, Any
1+
from typing import List, Dict, Any
22
from torch_tensorrt import _enums
33
import torch_tensorrt.ts
44
from torch_tensorrt import logging
@@ -33,16 +33,15 @@ def _parse_module_type(module: Any) -> _ModuleType:
3333
raise RuntimeError("Module is an unknown format")
3434

3535

36-
def _module_ir(module: Any, ir: str) -> Tuple[_ModuleType, _IRType]:
37-
module_type = _parse_module_type(module)
36+
def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:
3837
module_is_tsable = any([module_type == t for t in [_ModuleType.nn, _ModuleType.ts]])
3938
module_is_fxable = any([module_type == t for t in [_ModuleType.nn, _ModuleType.fx]])
4039

4140
ir_targets_torchscript = any([ir == opt for opt in ["torchscript", "ts"]])
4241
ir_targets_fx = ir == "fx"
4342

4443
if module_is_tsable and ir_targets_torchscript:
45-
return module_type, _IRType.ts
44+
return _IRType.ts
4645
elif module_is_fxable and ir_targets_fx:
4746
if module_type == _ModuleType.fx:
4847
raise ValueError("Was given a torch.fx.GraphModule, fx is not currently supported by Torch-TensorRT")
@@ -56,7 +55,7 @@ def _module_ir(module: Any, ir: str) -> Tuple[_ModuleType, _IRType]:
5655
# Options are listed in order of preference
5756
if module_is_tsable:
5857
logging.log(logging.Level.Info, "ir was set to default, using TorchScript as ir")
59-
return module_type, _IRType.ts
58+
return _IRType.ts
6059
elif module_is_fxable:
6160
raise ValueError("Was given a torch.fx.GraphModule, fx is not currently supported by Torch-TensorRT")
6261
#logging.log(logging.Level.Info, "ir was set to default, using TorchScript as fx")
@@ -103,7 +102,8 @@ def compile(module: Any, ir="default", inputs=[], enabled_precisions=set([_enums
103102
Returns:
104103
torch.nn.Module: Compiled Module, when run it will execute via TensorRT
105104
"""
106-
module_type, target_ir = _module_ir(module, ir)
105+
module_type = _parse_module_type(module)
106+
target_ir = _get_target_ir(module_type, ir)
107107
if target_ir == _IRType.ts:
108108
ts_mod = module
109109
if module_type == _ModuleType.nn:
@@ -152,11 +152,11 @@ def convert_method_to_trt_engine(module: Any,
152152
enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
153153
ir (str): The requested strategy to compile. (Options: default - Let Torch-TensorRT decide, ts - TorchScript with scripting path)
154154
**kwargs: Additional settings for the specific requested strategy (See submodules for more info)
155-
156155
Returns:
157156
bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
158157
"""
159-
target_ir = _module_ir(module, ir)
158+
module_type = _parse_module_type(module)
159+
target_ir = _get_target_ir(module_type, ir)
160160
if target_ir == _IRType.ts:
161161
ts_mod = module
162162
if module_type == _ModuleType.nn:
@@ -172,5 +172,4 @@ def convert_method_to_trt_engine(module: Any,
172172
elif target_ir == _IRType.fx:
173173
raise RuntimeError("fx is currently not supported")
174174
else:
175-
raise RuntimeError("Module is an unknown format or the ir requested is unknown")
176-
175+
raise RuntimeError("Module is an unknown format or the ir requested is unknown")

0 commit comments

Comments
 (0)