Skip to content

Commit 10c5fef

Browse files
frank-weiWei Wei
andauthored
[FX] Changes done internally at Facebook (#1625)
Co-authored-by: Wei Wei <[email protected]>
1 parent a343650 commit 10c5fef

File tree

14 files changed

+384
-220
lines changed

14 files changed

+384
-220
lines changed

.circleci/config.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ commands:
263263
parameters:
264264
torch-build:
265265
type: string
266-
default: "2.0.0.dev20230120+cu117"
266+
default: "2.0.0.dev20230129+cu117"
267267
torch-build-index:
268268
type: string
269269
default: "https://download.pytorch.org/whl/nightly/cu117"
@@ -1026,7 +1026,7 @@ parameters:
10261026
# Nightly platform config
10271027
torch-build:
10281028
type: string
1029-
default: "2.0.0.dev20230120+cu117"
1029+
default: "2.0.0.dev20230129+cu117"
10301030
torch-build-index:
10311031
type: string
10321032
default: "https://download.pytorch.org/whl/nightly/cu117"

py/torch_tensorrt/fx/converters/aten_ops_converters.py

Lines changed: 128 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -298,8 +298,6 @@ def aten_ops_sub(
298298
return acc_ops_converters.acc_ops_sub(network, target, None, kwargs_new, name)
299299

300300

301-
@tensorrt_converter(torch.ops.aten._unsafe_view.default)
302-
@tensorrt_converter(torch.ops.aten._reshape_alias.default)
303301
@tensorrt_converter(torch.ops.aten.view.default)
304302
def aten_ops_reshape(
305303
network: TRTNetwork,
@@ -308,11 +306,33 @@ def aten_ops_reshape(
308306
kwargs: Dict[str, Argument],
309307
name: str,
310308
) -> Union[TRTTensor, Sequence[TRTTensor]]:
311-
kwargs_new = {
312-
"input": args[0],
313-
"acc_out_ty": acc_utils.build_raw_tensor_meta(shape=args[1]),
314-
}
315-
return acc_ops_converters.acc_ops_reshape(network, target, None, kwargs_new, name)
309+
input_val = args[0]
310+
# for case where input_val is TRTensor
311+
input_val = get_trt_tensor(network, input_val, f"{name}_input_val")
312+
shape = args[1]
313+
314+
layer = network.add_shuffle(input_val)
315+
316+
if all(isinstance(s, int) for s in shape):
317+
layer.reshape_dims = tuple(shape)
318+
else:
319+
# Convert all the dimensions to trt Tensors.
320+
trt_shape = []
321+
322+
for i, s in enumerate(shape):
323+
if isinstance(s, TRTTensor):
324+
trt_shape.append(s)
325+
else:
326+
a = get_trt_tensor(network, s, f"{name}_{i}")
327+
trt_shape.append(a)
328+
329+
shape_layer = network.add_concatenation(inputs=trt_shape)
330+
shape_layer.axis = 0
331+
shape_layer.name = f"{name}_output_shape"
332+
layer.set_input(1, shape_layer.get_output(0))
333+
334+
set_layer_name(layer, target, name)
335+
return layer.get_output(0)
316336

317337

318338
@tensorrt_converter(torch.ops.aten.cat.default)
@@ -345,3 +365,104 @@ def aten_ops_expand(
345365
return acc_ops_converters.acc_ops_expand_tensor(
346366
network, target, None, kwargs_new, name
347367
)
368+
369+
370+
@tensorrt_converter(operator.floordiv)
371+
def aten_ops_operator_floordiv(
372+
network: TRTNetwork,
373+
target: Target,
374+
args: Tuple[Argument, ...],
375+
kwargs: Dict[str, Argument],
376+
name: str,
377+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
378+
kwargs_new = {
379+
"input": args[0],
380+
"other": args[1],
381+
}
382+
return acc_ops_converters.acc_ops_floor_div(network, target, None, kwargs_new, name)
383+
384+
385+
@tensorrt_converter(operator.mul)
386+
def aten_ops_operator_mul(
387+
network: TRTNetwork,
388+
target: Target,
389+
args: Tuple[Argument, ...],
390+
kwargs: Dict[str, Argument],
391+
name: str,
392+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
393+
kwargs_new = {
394+
"input": args[0],
395+
"other": args[1],
396+
}
397+
return acc_ops_converters.acc_ops_mul(network, target, None, kwargs_new, name)
398+
399+
400+
@tensorrt_converter(operator.add)
401+
def aten_ops_operator_add(
402+
network: TRTNetwork,
403+
target: Target,
404+
args: Tuple[Argument, ...],
405+
kwargs: Dict[str, Argument],
406+
name: str,
407+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
408+
kwargs_new = {
409+
"input": args[0],
410+
"other": args[1],
411+
}
412+
return acc_ops_converters.acc_ops_add(network, target, None, kwargs_new, name)
413+
414+
415+
@tensorrt_converter(operator.sub)
416+
def aten_ops_operator_sub(
417+
network: TRTNetwork,
418+
target: Target,
419+
args: Tuple[Argument, ...],
420+
kwargs: Dict[str, Argument],
421+
name: str,
422+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
423+
kwargs_new = {
424+
"input": args[0],
425+
"other": args[1],
426+
}
427+
return acc_ops_converters.acc_ops_sub(network, target, None, kwargs_new, name)
428+
429+
430+
@tensorrt_converter(torch.ops.aten.sym_numel)
431+
def aten_ops_sym_numel(
432+
network: TRTNetwork,
433+
target: Target,
434+
args: Tuple[Argument, ...],
435+
kwargs: Dict[str, Argument],
436+
name: str,
437+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
438+
shape_layer = network.add_shape(args[0])
439+
set_layer_name(shape_layer, target, "_shape_layer")
440+
reduce_layer = network.add_reduce(
441+
shape_layer.get_output(0),
442+
trt.ReduceOperation.PROD,
443+
axes=get_axes_for_reduce_op(0, False),
444+
keep_dims=True,
445+
)
446+
set_layer_name(reduce_layer, target, "_reduce_layer")
447+
return reduce_layer.get_output(0)
448+
449+
450+
@tensorrt_converter(torch.ops.aten.sym_size)
451+
def aten_ops_sym_size(
452+
network: TRTNetwork,
453+
target: Target,
454+
args: Tuple[Argument, ...],
455+
kwargs: Dict[str, Argument],
456+
name: str,
457+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
458+
shape_layer = network.add_shape(args[0])
459+
ind = args[1]
460+
set_layer_name(shape_layer, target, "_shape_layer")
461+
slice_layer = network.add_slice(
462+
input=shape_layer.get_output(0),
463+
start=[ind],
464+
shape=[1],
465+
stride=[1],
466+
)
467+
set_layer_name(slice_layer, target, "_slice_layer")
468+
return slice_layer.get_output(0)

py/torch_tensorrt/fx/converters/convolution.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
# @manual=//deeplearning/trt/python:py_tensorrt
2+
import logging
3+
24
import numpy as np
35
import tensorrt as trt
46
import torch
5-
import logging
67

78
from ..converter_registry import tensorrt_converter
89

py/torch_tensorrt/fx/lower.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import torch
88
import torch.fx as fx
99
import torch.nn as nn
10+
import torch_tensorrt.fx.tracer.dispatch_tracer.aten_tracer as aten_tracer
1011
from torch.fx.passes.splitter_base import SplitResult
1112

1213
from .fx2trt import TRTInterpreter, TRTInterpreterResult
@@ -18,8 +19,7 @@
1819

1920
from .tracer.acc_tracer import acc_tracer
2021
from .trt_module import TRTModule
21-
from .utils import LowerPrecision, proxytensor_trace
22-
22+
from .utils import LowerPrecision
2323

2424
logger = logging.getLogger(__name__)
2525

@@ -259,7 +259,9 @@ def create(
259259
return cls(
260260
lower_pass_manager_builder=LowerPassManagerBuilder(
261261
lower_setting=lower_setting,
262-
trace_func=lambda module, inputs: proxytensor_trace(module, inputs),
262+
trace_func=lambda module, inputs: aten_tracer.opt_trace(
263+
module, inputs
264+
),
263265
split_func=split_func,
264266
lower_func=default_lower_pass(interpreter_builder),
265267
)
@@ -308,14 +310,6 @@ def do_lower(module: nn.Module, inputs: Input) -> nn.Module:
308310
pm = self.lower_pass_manager_builder.build_trt_lower_pipeline(
309311
inputs, additional_inputs
310312
)
311-
if lower_setting.is_aten:
312-
pm = self.lower_pass_manager_builder.build_aten2trt_lower_pipeline(
313-
inputs, additional_inputs
314-
)
315-
else:
316-
pm = self.lower_pass_manager_builder.build_trt_lower_pipeline(
317-
inputs, additional_inputs
318-
)
319313
lower_result = pm(module)
320314
return lower_result
321315

py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,31 @@ def graph_optimization_pass(self) -> PassManager:
127127

128128
return PassManager.build_from_passlist(passes)
129129

130+
def graph_optimization_pass_aten(self) -> PassManager:
131+
passes = []
132+
133+
for p in self.lower_setting.customized_fuse_pass.passes:
134+
passes.append(wrapper(p, self._input))
135+
for p in self.lower_setting.lower_basic_fuse_pass.passes:
136+
passes.append(wrapper(p, self._input))
137+
# TODO fix this pass for aten graph
138+
# if (
139+
# hasattr(self.lower_setting, "lower_precision")
140+
# and self.lower_setting.lower_precision is LowerPrecision.FP16
141+
# ) or (
142+
# hasattr(self.lower_setting, "precision")
143+
# and self.lower_setting.precision is LowerPrecision.FP16
144+
# ):
145+
# passes.append(wrapper(fix_clamp_numerical_limits_to_fp16, self._input))
146+
147+
passes.append(
148+
inplace_wrapper(lambda m: FUSE_PASSES_POST_OBSERVER.observe(m, self._input))
149+
)
150+
# TODO we most likely do not need it for aten
151+
# passes.append(fix_reshape_batch_dim)
152+
153+
return PassManager.build_from_passlist(passes)
154+
130155
def _split_pass(self) -> PassManager:
131156
passes = [
132157
partial(
@@ -259,8 +284,7 @@ def build_aten2trt_lower_pipeline(
259284
passes.append(
260285
wrapper(self._trace_func, self._input),
261286
)
262-
passes.append(self._default_replace_mutable_op_pass())
263-
passes.append(self.graph_optimization_pass())
287+
passes.append(self.graph_optimization_pass_aten())
264288
passes.append(self._split_pass())
265289
passes.append(self._trt_lower_pass())
266290

py/torch_tensorrt/fx/test/converters/aten_op/test_maxpool_aten.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def forward(self, x):
7373
# param("ceil_mode", 1, ceil_mode=True),
7474
]
7575
)
76-
@unittest.skip("PT tracer issue")
76+
@unittest.skip("PT2 tracer issue")
7777
def test_max_pool3d(
7878
self,
7979
test_name,
@@ -95,6 +95,7 @@ def forward(self, x):
9595
inputs = [torch.randn(1, 3, 32, 32, 32)]
9696
self.run_test(TestModule(), inputs, expected_ops={})
9797

98+
@unittest.skip("PT2 tracer issue")
9899
def test_max_pool3d_with_dynamic_shape(self):
99100
class TestModule(torch.nn.Module):
100101
def __init__(self):
@@ -118,7 +119,7 @@ def forward(self, x):
118119
@parameterized.expand(
119120
[
120121
("default", 1),
121-
param("stride", 2, stride=()),
122+
# param("stride", 2, stride=()), #PT2 tracer issue
122123
]
123124
)
124125
def test_stride_none_max_pool2d(
@@ -147,7 +148,7 @@ def forward(self, x):
147148
param("stride", 2, stride=()),
148149
]
149150
)
150-
@unittest.skip("PT tracer issue")
151+
@unittest.skip("PT2 tracer issue")
151152
def test_stride_none_max_pool3d(
152153
self,
153154
test_name,
@@ -209,6 +210,7 @@ def forward(self, x):
209210
param("stride", 2, stride=()),
210211
]
211212
)
213+
@unittest.skip("PT2 tracer issue")
212214
def test_stride_none_max_pool3d_with_dynamic_shape(
213215
self,
214216
test_name,

0 commit comments

Comments
 (0)