import torch
from torch import fx
from enum import Enum
-
+from torch_tensorrt import fx

class _IRType(Enum):
    """Enum to set the minimum required logging level to print a message to stdout
@@ -43,13 +43,7 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:
    if module_is_tsable and ir_targets_torchscript:
        return _IRType.ts
    elif module_is_fxable and ir_targets_fx:
-        if module_type == _ModuleType.fx:
-            raise ValueError("Was given a torch.fx.GraphModule, fx is not currently supported by Torch-TensorRT")
-        elif ir_targets_fx:
-            raise ValueError("Preferred ir was set to \"fx\" which is currently not supported by Torch-TensorRT")
-        else:
-            raise ValueError("Torch-TensorRT currently does not support fx")
-        # return _IRType.fx
+        return _IRType.fx
    else:
        if ir == "default":
            # Options are listed in order of preference
@@ -114,7 +108,78 @@ def compile(module: Any, ir="default", inputs=[], enabled_precisions=set([_enums
        ts_mod = torch.jit.script(module)
        return torch_tensorrt.ts.compile(ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs)
    elif target_ir == _IRType.fx:
-        raise RuntimeError("fx is currently not supported")
+        from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer
+        from torch_tensorrt.fx import InputTensorSpec
+        from torch_tensorrt.fx import TRTInterpreter
+        from torch_tensorrt.fx.passes.lower_basic_pass import transform_setitem
+        from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter
+        from torch_tensorrt.fx.tools.trt_splitter import TRTSplitterSetting
+        from torch_tensorrt.fx.trt_module import TRTModule
+        from torch_tensorrt.fx.utils import LowerPrecision
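+        # Trace with acc_tracer, which normalizes the module into acc ops that the
+        # splitter and interpreter below know how to match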
+        acc_model = acc_tracer.trace(module, inputs)
+
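+        # Partition the traced graph into TensorRT-convertible submodules (named
+        # "_run_on_acc_*") and ones left to run in PyTorch, using explicit batch dimensions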
+        splitter_setting = TRTSplitterSetting()
+        splitter_setting.use_implicit_batch_dim = False
+        splitter = TRTSplitter(acc_model, inputs, settings=splitter_setting)
+        splitter.node_support_preview()
+        split_mod = splitter()
+        num_piece = 0
+        for name, _ in split_mod.named_children():
+            print(f"graph is split into {name}")
+            num_piece += 1
+
+        # If the graph module is split into more than 8 pieces, we consider its
+        # performance too poor and fall back to the non-TRT module
+        if num_piece > 8:
+            print(
+                f"The graph module is split into {num_piece} pieces, which is larger than the \
+                threshold=8. Falling back to non-TRT module."
+            )
+            return None
+
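+        # Map the requested enabled_precisions onto the fx path's LowerPrecision enum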
+        if torch.float16 in enabled_precisions or torch.half in enabled_precisions:
+            precision = LowerPrecision.FP16
+        else:
+            precision = LowerPrecision.FP32
+
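+        # Capture the concrete tensors fed to a submodule by temporarily attaching a
+        # forward pre-hook, so InputTensorSpec can be built from real shapes and dtypes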
+        def get_submod_inputs(mod, submod, inputs):
+            acc_inputs = None
+
+            def get_input(self, inputs):
+                nonlocal acc_inputs
+                acc_inputs = inputs
+
+            handle = submod.register_forward_pre_hook(get_input)
+            mod(*inputs)
+            handle.remove()
+            return acc_inputs
+
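+        # Compile each TensorRT-capable submodule and swap it into the split module in place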
+        for name, _ in split_mod.named_children():
+            if "_run_on_acc" in name:
+                submod = getattr(split_mod, name)
+                # Get submodule inputs for fx2trt
+                acc_inputs = get_submod_inputs(split_mod, submod, inputs)
+
+                # fx2trt replacement
+                interp = TRTInterpreter(
+                    submod,
+                    InputTensorSpec.from_tensors(acc_inputs),
+                    explicit_batch_dimension=True,
+                )
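+                # Build the TensorRT engine; 20 << 30 bytes allows up to a 20 GiB workspace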
+                r = interp.run(
+                    max_workspace_size=20 << 30,
+                    lower_precision=precision,
+                    # profiling_verbosity=trt.ProfilingVerbosity.DETAILED,  # For profiling
+                )
+                # For profiling:
+                # from fx2trt_oss.fx.tools.trt_profiler_sorted import profile_trt_module
+                # profile_trt_module("", trt_mod, acc_inputs)
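+                # interp.run() returns the engine together with its input/output names;
+                # TRTModule wraps them into a callable nn.Module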
+                trt_mod = TRTModule(*r)
+
+                setattr(split_mod, name, trt_mod)
+            else:
+                submod = getattr(split_mod, name)
+        return split_mod
    else:
        raise RuntimeError("Module is an unknown format or the ir requested is unknown")
@@ -173,4 +238,4 @@ def convert_method_to_trt_engine(module: Any,
    elif target_ir == _IRType.fx:
        raise RuntimeError("fx is currently not supported")
    else:
-        raise RuntimeError("Module is an unknown format or the ir requested is unknown")
+        raise RuntimeError("Module is an unknown format or the ir requested is unknown")