Skip to content

Commit 9f971ed

Browse files
author
liangtao07
committed
Fix the bug of incorrect model type identification
Signed-off-by: liangtao07 <[email protected]>
1 parent 55c3bab commit 9f971ed

File tree

1 file changed

+31
-12
lines changed

1 file changed

+31
-12
lines changed

py/torch_tensorrt/_compile.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Dict, Any
1+
from typing import List, Dict, Tuple, Any
22
from torch_tensorrt import _enums
33
import torch_tensorrt.ts
44
from torch_tensorrt import logging
@@ -14,19 +14,37 @@ class _IRType(Enum):
1414
fx = 1
1515

1616

17-
def _module_ir(module: Any, ir: str) -> _IRType.ts:
18-
# Possible module types
19-
module_is_tsable = any(
20-
isinstance(module, t) for t in [torch.nn.Module, torch.jit.ScriptModule, torch.jit.ScriptFunction])
21-
module_is_fxable = any(isinstance(module, t) for t in [torch.nn.Module, torch.fx.GraphModule])
17+
class _ModuleType(Enum):
18+
"""Enum to set the minimum required logging level to print a message to stdout
19+
"""
20+
nn = 0
21+
ts = 1
22+
fx = 2
23+
24+
25+
def _parse_module_type(module: Any) -> _ModuleType:
26+
if any(isinstance(module, t) for t in [torch.jit.ScriptModule, torch.jit.ScriptFunction]):
27+
return _ModuleType.ts
28+
elif isinstance(module, torch.fx.GraphModule):
29+
return _ModuleType.fx
30+
elif isinstance(module, torch.nn.Module):
31+
return _ModuleType.nn
32+
else:
33+
raise RuntimeError("Module is an unknown format")
34+
35+
36+
def _module_ir(module: Any, ir: str) -> Tuple[_ModuleType, _IRType]:
37+
module_type = _parse_module_type(module)
38+
module_is_tsable = any([module_type == t for t in [_ModuleType.nn, _ModuleType.ts]])
39+
module_is_fxable = any([module_type == t for t in [_ModuleType.nn, _ModuleType.fx]])
2240

2341
ir_targets_torchscript = any([ir == opt for opt in ["torchscript", "ts"]])
2442
ir_targets_fx = ir == "fx"
2543

2644
if module_is_tsable and ir_targets_torchscript:
27-
return _IRType.ts
45+
return module_type, _IRType.ts
2846
elif module_is_fxable and ir_targets_fx:
29-
if isinstance(module, torch.fx.GraphModule):
47+
if module_type == _ModuleType.fx:
3048
raise ValueError("Was given a torch.fx.GraphModule, fx is not currently supported by Torch-TensorRT")
3149
elif ir_targets_fx:
3250
raise ValueError("Preferred ir was set to \"fx\" which is currently not supported by Torch-TensorRT")
@@ -38,7 +56,7 @@ def _module_ir(module: Any, ir: str) -> _IRType.ts:
3856
# Options are listed in order of preference
3957
if module_is_tsable:
4058
logging.log(logging.Level.Info, "ir was set to default, using TorchScript as ir")
41-
return _IRType.ts
59+
return module_type, _IRType.ts
4260
elif module_is_fxable:
4361
raise ValueError("Was given a torch.fx.GraphModule, fx is not currently supported by Torch-TensorRT")
4462
#logging.log(logging.Level.Info, "ir was set to default, using TorchScript as fx")
@@ -85,10 +103,10 @@ def compile(module: Any, ir="default", inputs=[], enabled_precisions=set([_enums
85103
Returns:
86104
torch.nn.Module: Compiled Module, when run it will execute via TensorRT
87105
"""
88-
target_ir = _module_ir(module, ir)
106+
module_type, target_ir = _module_ir(module, ir)
89107
if target_ir == _IRType.ts:
90108
ts_mod = module
91-
if isinstance(module, torch.nn.Module):
109+
if module_type == _ModuleType.nn:
92110
logging.log(
93111
logging.Level.Info,
94112
"Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript"
@@ -141,7 +159,7 @@ def convert_method_to_trt_engine(module: Any,
141159
target_ir = _module_ir(module, ir)
142160
if target_ir == _IRType.ts:
143161
ts_mod = module
144-
if isinstance(module, torch.nn.Module):
162+
if module_type == _ModuleType.nn:
145163
logging.log(
146164
"Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript"
147165
)
@@ -155,3 +173,4 @@ def convert_method_to_trt_engine(module: Any,
155173
raise RuntimeError("fx is currently not supported")
156174
else:
157175
raise RuntimeError("Module is an unknown format or the ir requested is unknown")
176+

0 commit comments

Comments
 (0)