
Commit a229e53

Author: Wei Wei (committed)
[fx2trt] 1) restructure fx2trt codebase; 2) add fx2trt in compile(); 3) add test_fx2trt.py example
1 parent e9fad34 commit a229e53

File tree

144 files changed: +26201 -10 lines changed


py/torch_tensorrt/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@
 from torch_tensorrt._Input import Input
 from torch_tensorrt._Device import Device

+from torch_tensorrt import fx

 def _register_with_torch():
     trtorch_dir = os.path.dirname(__file__)

py/torch_tensorrt/_compile.py

Lines changed: 75 additions & 10 deletions
@@ -5,7 +5,7 @@
 import torch
 from torch import fx
 from enum import Enum
-
+from torch_tensorrt import fx

 class _IRType(Enum):
     """Enum to set the minimum required logging level to print a message to stdout
@@ -43,13 +43,7 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:
     if module_is_tsable and ir_targets_torchscript:
         return _IRType.ts
     elif module_is_fxable and ir_targets_fx:
-        if module_type == _ModuleType.fx:
-            raise ValueError("Was given a torch.fx.GraphModule, fx is not currently supported by Torch-TensorRT")
-        elif ir_targets_fx:
-            raise ValueError("Preferred ir was set to \"fx\" which is currently not supported by Torch-TensorRT")
-        else:
-            raise ValueError("Torch-TensorRT currently does not support fx")
-        # return _IRType.fx
+        return _IRType.fx
     else:
         if ir == "default":
             # Options are listed in order of preference
@@ -114,7 +108,78 @@ def compile(module: Any, ir="default", inputs=[], enabled_precisions=set([_enums
         ts_mod = torch.jit.script(module)
         return torch_tensorrt.ts.compile(ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs)
     elif target_ir == _IRType.fx:
-        raise RuntimeError("fx is currently not supported")
+        from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer
+        from torch_tensorrt.fx import InputTensorSpec
+        from torch_tensorrt.fx import TRTInterpreter
+        from torch_tensorrt.fx.passes.lower_basic_pass import transform_setitem
+        from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter
+        from torch_tensorrt.fx.tools.trt_splitter import TRTSplitterSetting
+        from torch_tensorrt.fx.trt_module import TRTModule
+        from torch_tensorrt.fx.utils import LowerPrecision
+        acc_model = acc_tracer.trace(module, inputs)
+
+        splitter_setting = TRTSplitterSetting()
+        splitter_setting.use_implicit_batch_dim = False
+        splitter = TRTSplitter(acc_model, inputs, settings=splitter_setting)
+        splitter.node_support_preview()
+        split_mod = splitter()
+        num_piece = 0
+        for name, _ in split_mod.named_children():
+            print(f"graph is split into {name}")
+            num_piece += 1
+
+        # If the graph module is split into more than 8 pieces, we consider its perf
+        # is not good and fall back to the non-TRT module.
+        if num_piece > 8:
+            print(
+                f"The graph module is split into {num_piece} pieces, which is larger "
+                "than the threshold of 8. Falling back to the non-TRT module."
+            )
+            return None
+
+        if torch.float16 in enabled_precisions or torch.half in enabled_precisions:
+            precision = LowerPrecision.FP16
+        else:
+            precision = LowerPrecision.FP32
+
+        def get_submod_inputs(mod, submod, inputs):
+            acc_inputs = None
+
+            def get_input(self, inputs):
+                nonlocal acc_inputs
+                acc_inputs = inputs
+
+            handle = submod.register_forward_pre_hook(get_input)
+            mod(*inputs)
+            handle.remove()
+            return acc_inputs
+
+        for name, _ in split_mod.named_children():
+            if "_run_on_acc" in name:
+                submod = getattr(split_mod, name)
+                # Get submodule inputs for fx2trt
+                acc_inputs = get_submod_inputs(split_mod, submod, inputs)
+
+                # fx2trt replacement
+                interp = TRTInterpreter(
+                    submod,
+                    InputTensorSpec.from_tensors(acc_inputs),
+                    explicit_batch_dimension=True,
+                )
+                r = interp.run(
+                    max_workspace_size=20 << 30,
+                    lower_precision=precision,
+                    # profiling_verbosity=trt.ProfilingVerbosity.DETAILED,  # for profiling
+                )
+                # For profiling:
+                # from fx2trt_oss.fx.tools.trt_profiler_sorted import profile_trt_module
+                # profile_trt_module("", trt_mod, acc_inputs)
+                trt_mod = TRTModule(*r)
+
+                setattr(split_mod, name, trt_mod)
+            else:
+                submod = getattr(split_mod, name)
+        return split_mod
     else:
         raise RuntimeError("Module is an unknown format or the ir requested is unknown")

@@ -173,4 +238,4 @@ def convert_method_to_trt_engine(module: Any,
     elif target_ir == _IRType.fx:
         raise RuntimeError("fx is currently not supported")
     else:
-        raise RuntimeError("Module is an unknown format or the ir requested is unknown")
+        raise RuntimeError("Module is an unknown format or the ir requested is unknown")
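
With this change, compile() routes modules through fx2trt when ir="fx": the module is traced with acc_tracer, split with TRTSplitter, and every "_run_on_acc" submodule is replaced by a TRTModule. A minimal usage sketch, assuming a CUDA machine with TensorRT installed; the toy model and input shape are illustrative and not part of this commit:

import torch
import torch_tensorrt

# Hypothetical toy model, used only for illustration.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, kernel_size=3, padding=1),
    torch.nn.ReLU(),
).eval().cuda()

inputs = [torch.randn(1, 3, 224, 224, device="cuda")]

# ir="fx" now resolves to _IRType.fx instead of raising; the result is the split
# GraphModule with its TRT-compatible children swapped for TRTModules.
trt_mod = torch_tensorrt.compile(
    model,
    ir="fx",
    inputs=inputs,
    enabled_precisions={torch.half},
)
out = trt_mod(*inputs)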

py/torch_tensorrt/fx/__init__.py

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+from .converters import *  # noqa: F403 F401
+from .converter_registry import (  # noqa
+    CONVERTERS,
+    NO_EXPLICIT_BATCH_DIM_SUPPORT,
+    NO_IMPLICIT_BATCH_DIM_SUPPORT,
+    tensorrt_converter,
+)
+from .fx2trt import TRTInterpreter, TRTInterpreterResult  # noqa
+from .input_tensor_spec import InputTensorSpec  # noqa
+from .trt_module import TRTModule  # noqa
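
The new package __init__ re-exports the core fx2trt entry points, so downstream code can import them directly from torch_tensorrt.fx. A small sketch, assuming a TensorRT-enabled environment; the example tensor is illustrative:

import torch
from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter, TRTModule

# Input specs can be built straight from example tensors, as _compile.py does above.
specs = InputTensorSpec.from_tensors([torch.randn(1, 3, 224, 224, device="cuda")])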
py/torch_tensorrt/fx/converter_registry.py

Lines changed: 31 additions & 0 deletions

@@ -0,0 +1,31 @@
+from typing import Any, Callable, Dict
+
+from torch.fx.node import Target
+
+
+CONVERTERS: Dict[Target, Any] = {}
+NO_IMPLICIT_BATCH_DIM_SUPPORT = {}
+NO_EXPLICIT_BATCH_DIM_SUPPORT = {}
+
+
+def tensorrt_converter(
+    key: Target,
+    no_implicit_batch_dim: bool = False,
+    no_explicit_batch_dim: bool = False,
+    enabled: bool = True,
+) -> Callable[[Any], Any]:
+    def register_converter(converter):
+        CONVERTERS[key] = converter
+        if no_implicit_batch_dim:
+            NO_IMPLICIT_BATCH_DIM_SUPPORT[key] = converter
+        if no_explicit_batch_dim:
+            NO_EXPLICIT_BATCH_DIM_SUPPORT[key] = converter
+        return converter
+
+    def disable_converter(converter):
+        return converter
+
+    if enabled:
+        return register_converter
+    else:
+        return disable_converter
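
The registry maps an fx Target (a function or module type) to the converter that emits the corresponding TensorRT layers. A hedged sketch of registering a custom converter; the target function, the converter body, and the (network, target, args, kwargs, name) signature follow the usual fx2trt convention and are illustrative, not part of this commit:

from torch_tensorrt.fx.converter_registry import (
    CONVERTERS,
    NO_IMPLICIT_BATCH_DIM_SUPPORT,
    tensorrt_converter,
)

def my_scale(x, factor):  # hypothetical target key, used only for illustration
    return x * factor

@tensorrt_converter(my_scale, no_implicit_batch_dim=True)
def my_scale_converter(network, target, args, kwargs, name):
    # A real converter would add TensorRT layers to `network` and return their output.
    raise NotImplementedError("illustrative stub")

assert CONVERTERS[my_scale] is my_scale_converter
assert my_scale in NO_IMPLICIT_BATCH_DIM_SUPPORT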
py/torch_tensorrt/fx/converters/__init__.py

Lines changed: 18 additions & 0 deletions

@@ -0,0 +1,18 @@
+# @manual=//deeplearning/trt/python:py_tensorrt
+import tensorrt as trt
+
+if hasattr(trt, "__version__"):
+    from .activation import *  # noqa: F401 F403
+    from .adaptive_avgpool import *  # noqa: F401 F403
+    from .add import *  # noqa: F401 F403
+    from .batchnorm import *  # noqa: F401 F403
+    from .convolution import *  # noqa: F401 F403
+    from .linear import *  # noqa: F401 F403
+    from .maxpool import *  # noqa: F401 F403
+    from .mul import *  # noqa: F401 F403
+    from .transformation import *  # noqa: F401 F403
+    from .quantization import *  # noqa: F401 F403
+    from .acc_ops_converters import *  # noqa: F401 F403
+
+    TRT_LOGGER = trt.Logger()
+    trt.init_libnvinfer_plugins(TRT_LOGGER, "")
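
Because registration happens as a side effect of these star imports, simply importing the converters package (guarded on a working tensorrt install) populates the registry. A minimal check, assuming TensorRT is available:

import torch_tensorrt.fx.converters  # noqa: F401  -- star imports register the built-in converters
from torch_tensorrt.fx.converter_registry import CONVERTERS

print(f"{len(CONVERTERS)} converters registered")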
