1
- from typing import List , Dict , Tuple , Any
1
+ from typing import List , Dict , Any
2
2
from torch_tensorrt import _enums
3
3
import torch_tensorrt .ts
4
4
from torch_tensorrt import logging
@@ -33,16 +33,15 @@ def _parse_module_type(module: Any) -> _ModuleType:
33
33
raise RuntimeError ("Module is an unknown format" )
34
34
35
35
36
- def _module_ir (module : Any , ir : str ) -> Tuple [_ModuleType , _IRType ]:
37
- module_type = _parse_module_type (module )
36
+ def _get_target_ir (module_type : _ModuleType , ir : str ) -> _IRType :
38
37
module_is_tsable = any ([module_type == t for t in [_ModuleType .nn , _ModuleType .ts ]])
39
38
module_is_fxable = any ([module_type == t for t in [_ModuleType .nn , _ModuleType .fx ]])
40
39
41
40
ir_targets_torchscript = any ([ir == opt for opt in ["torchscript" , "ts" ]])
42
41
ir_targets_fx = ir == "fx"
43
42
44
43
if module_is_tsable and ir_targets_torchscript :
45
- return module_type , _IRType .ts
44
+ return _IRType .ts
46
45
elif module_is_fxable and ir_targets_fx :
47
46
if module_type == _ModuleType .fx :
48
47
raise ValueError ("Was given a torch.fx.GraphModule, fx is not currently supported by Torch-TensorRT" )
@@ -56,7 +55,7 @@ def _module_ir(module: Any, ir: str) -> Tuple[_ModuleType, _IRType]:
56
55
# Options are listed in order of preference
57
56
if module_is_tsable :
58
57
logging .log (logging .Level .Info , "ir was set to default, using TorchScript as ir" )
59
- return module_type , _IRType .ts
58
+ return _IRType .ts
60
59
elif module_is_fxable :
61
60
raise ValueError ("Was given a torch.fx.GraphModule, fx is not currently supported by Torch-TensorRT" )
62
61
#logging.log(logging.Level.Info, "ir was set to default, using TorchScript as fx")
@@ -103,7 +102,8 @@ def compile(module: Any, ir="default", inputs=[], enabled_precisions=set([_enums
103
102
Returns:
104
103
torch.nn.Module: Compiled Module, when run it will execute via TensorRT
105
104
"""
106
- module_type , target_ir = _module_ir (module , ir )
105
+ module_type = _parse_module_type (module )
106
+ target_ir = _get_target_ir (module_type , ir )
107
107
if target_ir == _IRType .ts :
108
108
ts_mod = module
109
109
if module_type == _ModuleType .nn :
@@ -152,11 +152,11 @@ def convert_method_to_trt_engine(module: Any,
152
152
enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
153
153
ir (str): The requested strategy to compile. (Options: default - Let Torch-TensorRT decide, ts - TorchScript with scripting path)
154
154
**kwargs: Additional settings for the specific requested strategy (See submodules for more info)
155
-
156
155
Returns:
157
156
bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
158
157
"""
159
- target_ir = _module_ir (module , ir )
158
+ module_type = _parse_module_type (module )
159
+ target_ir = _get_target_ir (module_type , ir )
160
160
if target_ir == _IRType .ts :
161
161
ts_mod = module
162
162
if module_type == _ModuleType .nn :
@@ -172,5 +172,4 @@ def convert_method_to_trt_engine(module: Any,
172
172
elif target_ir == _IRType .fx :
173
173
raise RuntimeError ("fx is currently not supported" )
174
174
else :
175
- raise RuntimeError ("Module is an unknown format or the ir requested is unknown" )
176
-
175
+ raise RuntimeError ("Module is an unknown format or the ir requested is unknown" )
0 commit comments