1
- from typing import List , Dict , Any
1
+ from typing import List , Dict , Tuple , Any
2
2
from torch_tensorrt import _enums
3
3
import torch_tensorrt .ts
4
4
from torch_tensorrt import logging
@@ -14,19 +14,37 @@ class _IRType(Enum):
14
14
fx = 1
15
15
16
16
17
- def _module_ir (module : Any , ir : str ) -> _IRType .ts :
18
- # Possible module types
19
- module_is_tsable = any (
20
- isinstance (module , t ) for t in [torch .nn .Module , torch .jit .ScriptModule , torch .jit .ScriptFunction ])
21
- module_is_fxable = any (isinstance (module , t ) for t in [torch .nn .Module , torch .fx .GraphModule ])
17
+ class _ModuleType (Enum ):
18
+ """Enum to set the minimum required logging level to print a message to stdout
19
+ """
20
+ nn = 0
21
+ ts = 1
22
+ fx = 2
23
+
24
+
25
+ def _parse_module_type (module : Any ) -> _ModuleType :
26
+ if any (isinstance (module , t ) for t in [torch .jit .ScriptModule , torch .jit .ScriptFunction ]):
27
+ return _ModuleType .ts
28
+ elif isinstance (module , torch .fx .GraphModule ):
29
+ return _ModuleType .fx
30
+ elif isinstance (module , torch .nn .Module ):
31
+ return _ModuleType .nn
32
+ else :
33
+ raise RuntimeError ("Module is an unknown format" )
34
+
35
+
36
+ def _module_ir (module : Any , ir : str ) -> Tuple [_ModuleType , _IRType ]:
37
+ module_type = _parse_module_type (module )
38
+ module_is_tsable = any ([module_type == t for t in [_ModuleType .nn , _ModuleType .ts ]])
39
+ module_is_fxable = any ([module_type == t for t in [_ModuleType .nn , _ModuleType .fx ]])
22
40
23
41
ir_targets_torchscript = any ([ir == opt for opt in ["torchscript" , "ts" ]])
24
42
ir_targets_fx = ir == "fx"
25
43
26
44
if module_is_tsable and ir_targets_torchscript :
27
- return _IRType .ts
45
+ return module_type , _IRType .ts
28
46
elif module_is_fxable and ir_targets_fx :
29
- if isinstance ( module , torch .fx . GraphModule ) :
47
+ if module_type == _ModuleType .fx :
30
48
raise ValueError ("Was given a torch.fx.GraphModule, fx is not currently supported by Torch-TensorRT" )
31
49
elif ir_targets_fx :
32
50
raise ValueError ("Preferred ir was set to \" fx\" which is currently not supported by Torch-TensorRT" )
@@ -38,7 +56,7 @@ def _module_ir(module: Any, ir: str) -> _IRType.ts:
38
56
# Options are listed in order of preference
39
57
if module_is_tsable :
40
58
logging .log (logging .Level .Info , "ir was set to default, using TorchScript as ir" )
41
- return _IRType .ts
59
+ return module_type , _IRType .ts
42
60
elif module_is_fxable :
43
61
raise ValueError ("Was given a torch.fx.GraphModule, fx is not currently supported by Torch-TensorRT" )
44
62
#logging.log(logging.Level.Info, "ir was set to default, using TorchScript as fx")
@@ -85,10 +103,10 @@ def compile(module: Any, ir="default", inputs=[], enabled_precisions=set([_enums
85
103
Returns:
86
104
torch.nn.Module: Compiled Module, when run it will execute via TensorRT
87
105
"""
88
- target_ir = _module_ir (module , ir )
106
+ module_type , target_ir = _module_ir (module , ir )
89
107
if target_ir == _IRType .ts :
90
108
ts_mod = module
91
- if isinstance ( module , torch .nn . Module ) :
109
+ if module_type == _ModuleType .nn :
92
110
logging .log (
93
111
logging .Level .Info ,
94
112
"Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript"
@@ -141,7 +159,7 @@ def convert_method_to_trt_engine(module: Any,
141
159
target_ir = _module_ir (module , ir )
142
160
if target_ir == _IRType .ts :
143
161
ts_mod = module
144
- if isinstance ( module , torch .nn . Module ) :
162
+ if module_type == _ModuleType .nn :
145
163
logging .log (
146
164
"Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript"
147
165
)
@@ -155,3 +173,4 @@ def convert_method_to_trt_engine(module: Any,
155
173
raise RuntimeError ("fx is currently not supported" )
156
174
else :
157
175
raise RuntimeError ("Module is an unknown format or the ir requested is unknown" )
176
+
0 commit comments