
Commit 2717771

Authored by Yinghai Lu
Changes done internally at Facebook (#1161)
91196dedd10ed4061bef719b01fe917192a6fc17 Shiyan Deng <[email protected]> fix EngineHolder
608acce060d91ac93099f59e6165d43ffd79d50a Shiyan Deng <[email protected]> delete unused file
f88cb1566ae1c72634668fa26fadc3ef368b9727 Alex Beloi <[email protected]> [fx] add common_subexpression_elimination graph opt
b079de89bce912c29284809722cae63cb185c481 Yinghai Lu <[email protected]> handle list of integers in SliceOp check
e8f8dbda72625872a2657cdfd09dd59209f98090 Shiyan Deng <[email protected]> Add a pass to eliminate unsqueeze + cat + getitem pattern
675b28120c432c63bf4d451a4479b2376f73201d Kefei Lu <[email protected]> Back out "[fx] add common_subexpression_elimination graph opt"
a3fb38bb6193b86a0ef0795a9104c073f5044598 Shiyan Deng <[email protected]> Back out "Back out "[fx] add common_subexpression_elimination graph opt""
0ea2b5782d70ed07b2158a96a45082aba68c55f6 Andrew Or <[email protected]> [Quant][fx][bc-breaking] Replace is_reference with convert_to_reference
2e4384397e73e99a3d56303864dc09e23420dda1 Shreyansh Prajapati <[email protected]> Test dynamic shape suport for acc_ops.sigmoid
78f7505cbe18850f43261b5f92e67262127e229e Mor Tzur <[email protected]> uncomment accidentally commented out acc_op tests
a7c0210a5c6468d978fc55a541b08da4e1782df7 Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.avgpool and error message change
8aab78b134c778bcafe56f0fd5baa9bf51d7fef9 Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.any
0591b671542c186a257152f8ed82aecbd2322feb Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.avgpool
5d36837498f72f5ad3432eecf3f4e4091ac23071 Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.batchnorm
9a3134742872040b31a70bf216c81de5dcd82c45 Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.binary_ops
4802100968d39418d6d1ce75c0407b3344e5cfd9 Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.as_strided
6ccf77223a40e16d635d453b6342658bdc44062e Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.cat
0dd29ac935705926cc29bf03f5da922d93adae17 Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.chunk
81737ebcecb0a57a683971d9052b3f3262030d20 Shreyansh Prajapati <[email protected]> Added test case for testing dynamic shape support for acc_ops.clamp
0c2604c6db94309c8a76de00e8c69ce21d2114a8 Yinghai Lu <[email protected]> [fx2trt] Fix constant tensor conversion dtype bug
f0edd3db9b8f74103b536e1f2af81ce8558d2f88 Lu Fang <[email protected]> [Not for landing] Unblock ifr_unified_1_with_cover_dhen's lowering
e610d365fa21f495d9b651e519ce8a407d3aa744 Lu Fang <[email protected]> [Not for landing] Unblock ifr_unified_1_with_cover_dhen's lowering
eec55e9c3bab30dea9cde52feb9cc74f30a43035 Yinghai Lu <[email protected]> [fx2trt] Fix constant tensor conversion dtype bug
851ca556610bd396645a9822932939d12d6ed915 Yinghai Lu <[email protected]> [fx2trt] Fix constant tensor conversion dtype bug
1 parent e19cbfc commit 2717771

20 files changed: +691 / -150 lines

README.md

Lines changed: 1 addition & 1 deletion
@@ -213,7 +213,7 @@ bazel build //:libtorchtrt --compilation_mode opt
 ```
 
 ### FX path (Python only) installation
-If the user plan to try FX path (Python only) and would like to avoid bazel build. Please follow the steps below.
+If the user plans to try FX path (Python only) and would like to avoid bazel build. Please follow the steps below.
 ``` shell
 cd py && python3 setup.py install --fx-only
 ```

examples/fx/quantized_resnet_test.py

Lines changed: 6 additions & 2 deletions
@@ -6,7 +6,11 @@
 
 import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
 import torchvision.models as models
-from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx
+from torch.ao.quantization.quantize_fx import (
+    convert_fx,
+    convert_to_reference,
+    prepare_fx,
+)
 from torch.fx.experimental.normalize import NormalizeArgs
 from torch.fx.passes import shape_prop
 from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter, TRTModule
@@ -48,7 +52,7 @@ def build_int8_trt(rn18):
     prepared = prepare_fx(rn18, {"": qconfig})
     for _ in range(10):
         prepared(data)
-    quantized_rn18 = convert_fx(prepared, is_reference=True)
+    quantized_rn18 = convert_to_reference(prepared)
     ref_res = quantized_rn18(data)
     print("quantized model:", quantized_rn18)
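
This tracks the BC-breaking quantization change listed in the commit message (0ea2b578, "Replace is_reference with convert_to_reference"): `convert_fx(prepared, is_reference=True)` becomes `convert_to_reference(prepared)`. Below is a minimal sketch of the updated calibrate-and-convert flow, mirroring the diff; the qconfig helper and input shape are illustrative, and the API names follow the PyTorch version this commit targets (later releases renamed the function again).

```python
import torch
import torchvision.models as models
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import convert_to_reference, prepare_fx

rn18 = models.resnet18().eval()
data = torch.randn(1, 3, 224, 224)

qconfig = get_default_qconfig("fbgemm")
prepared = prepare_fx(rn18, {"": qconfig})       # insert observers
for _ in range(10):
    prepared(data)                               # calibrate
quantized_rn18 = convert_to_reference(prepared)  # was: convert_fx(prepared, is_reference=True)
ref_res = quantized_rn18(data)
```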

py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 1 addition & 1 deletion
@@ -2222,7 +2222,7 @@ def acc_ops_adaptive_avg_poolnd(
     extend_len = 2 if target == acc_ops.adaptive_avg_pool2d else 3
     assert all(
         input_val.shape[-(i + 1)] != -1 for i in range(extend_len)
-    ), "AdaptiveAvgPool2d currently doesn't support dynamic shapes for last two dims."
+    ), "AdaptiveAvgPool2d and AdaptiveAvgPool3d currently doesn't support dynamic shapes for last two dims."
 
     output_size = cast(
         Sequence[int], extend_attr_to_tuple(kwargs["output_size"], extend_len)
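
The assertion guards the last `extend_len` dimensions (2 for adaptive_avg_pool2d, 3 for adaptive_avg_pool3d) against the dynamic-shape marker -1; only the error message changes here. A standalone restatement of that check, with an illustrative helper name and example shapes (not part of the converter code):

```python
def trailing_dims_are_static(shape, extend_len):
    # True only when none of the trailing `extend_len` dims is the dynamic marker -1.
    return all(shape[-(i + 1)] != -1 for i in range(extend_len))


assert trailing_dims_are_static((8, 3, 224, 224), extend_len=2)
assert not trailing_dims_are_static((-1, 3, -1, -1), extend_len=2)  # dynamic H/W rejected
```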

py/torch_tensorrt/fx/converters/converter_utils.py

Lines changed: 15 additions & 10 deletions
@@ -415,9 +415,8 @@ def add_binary_elementwise_layer(
     This function adds a TensorRT elementwise layer. We allow both operands to be
     constant (not a trt tensor) because in implicit batch dimension mode, we could
     introduce constant via .size() op. Other scenario should be const folded first.
-    If any operand is not a trt tensor, we make it a trt constant layer which has
-    the same type as the other trt tensor. Then we broadcast these two inputs to
-    have the same number of dimensions.
+    If any operand is not a trt tensor, we make it a trt constant layer while preserve
+    its dtype. Then we broadcast these two inputs to have the same number of dimensions.
 
     Limitation:
         If we are using implicit batch dim mode, the operand that is not a trt
@@ -436,14 +435,16 @@ def add_binary_elementwise_layer(
     Returns:
         The output of TensorRT Elementwise layer.
     """
-    dtype = None
+    lhs_dtype = None
+    rhs_dtype = None
     is_lhs_trt_tensor = False
     is_rhs_trt_tensor = False
+
     if isinstance(lhs_val, TRTTensor):
-        dtype = torch_dtype_from_trt(lhs_val.dtype)
+        lhs_dtype = torch_dtype_from_trt(lhs_val.dtype)
         is_lhs_trt_tensor = True
     if isinstance(rhs_val, TRTTensor):
-        dtype = torch_dtype_from_trt(rhs_val.dtype)
+        rhs_dtype = torch_dtype_from_trt(rhs_val.dtype)
         is_rhs_trt_tensor = True
 
     if not is_lhs_trt_tensor and not is_rhs_trt_tensor:
@@ -463,10 +464,14 @@
     # this way the shape will become [1], and then will be properly squeezed
     # into [], meaning that the result will have shape [], which is what we
    # expect.
+    #
+    # Note that the dtype here is supposed to be the same as the scalar
+    # dtype but we don't have a way to detect whether it makes sense for the
+    # scalar to be float or half. Hence we go with the lhs dtype.
     if is_lhs_trt_tensor and isinstance(rhs_val, (float, int)):
-        rhs_val = torch.tensor([rhs_val], dtype=dtype)
+        rhs_val = torch.tensor([rhs_val], dtype=lhs_dtype)
     if is_rhs_trt_tensor and isinstance(lhs_val, (float, int)):
-        lhs_val = torch.tensor([lhs_val], dtype=dtype)
+        lhs_val = torch.tensor([lhs_val], dtype=rhs_dtype)
 
     # When lhs is scalar, and rhs has shape [1,], then currently the assert
     # will fail because lhs shape has fewer dimensions than rhs shape. This
@@ -482,8 +487,8 @@
     if isinstance(rhs_val, torch.Tensor):
         rhs_val = squeeze_left(rhs_val)
 
-    lhs_val = get_trt_tensor(network, lhs_val, f"{name}_lhs", dtype)
-    rhs_val = get_trt_tensor(network, rhs_val, f"{name}_rhs", dtype)
+    lhs_val = get_trt_tensor(network, lhs_val, f"{name}_lhs", lhs_dtype)
+    rhs_val = get_trt_tensor(network, rhs_val, f"{name}_rhs", rhs_dtype)
 
     # Check the limitation in the doc string.
     if network.has_implicit_batch_dimension:
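
The bug being fixed ("[fx2trt] Fix constant tensor conversion dtype bug" in the commit list): a single shared `dtype` variable meant a constant operand could be re-created with the other operand's dtype. The rewrite tracks `lhs_dtype` and `rhs_dtype` separately and only borrows the opposite side's dtype for bare Python scalars, which carry no dtype of their own. A framework-free sketch of that rule, with illustrative names (not the converter API):

```python
from typing import Optional

import torch


def wrap_scalar(scalar, other_dtype: Optional[torch.dtype]) -> torch.Tensor:
    # A bare Python scalar carries no dtype, so it follows the tensor on the
    # other side of the elementwise op; real constant tensors keep their own dtype.
    return torch.tensor([scalar], dtype=other_dtype)


half_lhs = torch.tensor([1.0, 2.0], dtype=torch.half)  # stands in for an fp16 TRT tensor
rhs = wrap_scalar(3, other_dtype=half_lhs.dtype)       # scalar follows lhs -> fp16
fp32_const = torch.tensor([4.0])                       # a real constant stays fp32
print(rhs.dtype, fp32_const.dtype)                     # torch.float16 torch.float32
```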

py/torch_tensorrt/fx/fx2trt.py

Lines changed: 1 addition & 1 deletion
@@ -80,7 +80,7 @@ def __init__(
         ] = dict()
 
     def validate_input_specs(self):
-        for shape, dtpe, _, shape_ranges, has_batch_dim in self.input_specs:
+        for shape, _, _, shape_ranges, has_batch_dim in self.input_specs:
             if not self.network.has_implicit_batch_dimension:
                 assert (
                     has_batch_dim
py/torch_tensorrt/fx/passes/graph_opts.py

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
+from collections.abc import Sequence
+
+import torch
+import torch.fx
+
+
+def common_subexpression_elimination(graph_module: torch.fx.GraphModule) -> bool:
+    """
+    Optimize quantization by removing repeated subexpressions.
+
+    Args:
+        graph_module(torch.fx.GraphModule): target module to be optimized
+
+    Returns:
+        Graph changed or not.
+    """
+
+    def seq_hashable(seq):
+        if seq is None:
+            return None
+
+        items = []
+        for old in seq:
+            if isinstance(old, Sequence) and not isinstance(old, str):
+                new = seq_hashable(old)
+            elif isinstance(old, dict):
+                new = dict_hashable(old)
+            elif isinstance(old, slice):
+                new = old.__reduce__()
+            else:
+                new = old
+
+            items.append(new)
+
+        return tuple(items)
+
+    def dict_hashable(d):
+        if d is None:
+            return None
+
+        items = []
+        for k, old_v in d.items():
+            if isinstance(old_v, Sequence):
+                new_v = seq_hashable(old_v)
+            elif isinstance(old_v, dict):
+                new_v = dict_hashable(old_v)
+            elif isinstance(old_v, slice):
+                new_v = old_v.__reduce__()
+            else:
+                new_v = old_v
+
+            items.append((k, new_v))
+        return tuple(sorted(items))
+
+    changed = False
+    env = {}
+    for n in graph_module.graph.nodes:
+        # do not CSE away impure ops
+        if n.op not in {"call_function", "call_method"} or n.is_impure():
+            continue
+
+        # hash target, args, kwargs
+        hash_val = (n.target, seq_hashable(n.args), dict_hashable(n.kwargs))
+
+        # check if a node has a substitute and can be eliminated
+        if hash_val in env:
+            n.replace_all_uses_with(env[hash_val])
+            graph_module.graph.erase_node(n)
+            changed = True
+            continue
+
+        env[hash_val] = n
+
+    return changed
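
The pass hashes each pure call_function/call_method node by (target, args, kwargs) and redirects later duplicates to the first occurrence. A minimal usage sketch; the import path is assumed from the relative import added in lower_pass_manager_builder.py, and the module and input shape are illustrative:

```python
import torch
import torch.fx

# Assumed location of the new file (it is imported as `.graph_opts` from the passes package).
from torch_tensorrt.fx.passes.graph_opts import common_subexpression_elimination


class Repeated(torch.nn.Module):
    def forward(self, x):
        a = torch.relu(x) + 1
        b = torch.relu(x) + 1  # same target/args as `a`; CSE folds it away
        return a * b


gm = torch.fx.symbolic_trace(Repeated())
changed = common_subexpression_elimination(gm)  # True if any node was eliminated
gm.recompile()

x = torch.randn(4)
assert changed and torch.equal(gm(x), Repeated()(x))
```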

py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py

Lines changed: 4 additions & 1 deletion
@@ -1,5 +1,5 @@
 from functools import partial, wraps
-from typing import Any, Callable, NamedTuple, Sequence
+from typing import Any, Callable, Sequence
 
 import torch
 from torch import nn
@@ -10,6 +10,7 @@
 from ..lower_setting import LowerSetting
 from ..observer import Observer
 from ..passes.remove_duplicate_output_args import remove_duplicate_output_args
+from .graph_opts import common_subexpression_elimination
 
 from .lower_basic_pass import run_const_fold
 
@@ -94,6 +95,8 @@ def graph_optimization_pass(self) -> PassManager:
             passes.append(wrapper(p, self._input))
         for p in self.lower_setting.lower_basic_fuse_pass.passes:
             passes.append(wrapper(p, self._input))
+
+        passes.append(inplace_wrapper(common_subexpression_elimination))
         passes.append(
             inplace_wrapper(lambda m: FUSE_PASSES_POST_OBSERVER.observe(m, self._input))
         )
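
`common_subexpression_elimination` mutates the module in place and returns a bool, while the pass pipeline threads a GraphModule through each pass; `inplace_wrapper` bridges the two. A simplified sketch of that adapter, assuming it behaves like the torch.fx pass-manager helper of the same name (not the exact library code):

```python
import torch.fx


def inplace_wrapper_sketch(fn):
    # Adapt "mutate gm, return something else" into "mutate gm, return gm".
    def wrapped(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        fn(gm)  # e.g. common_subexpression_elimination(gm) -> bool, discarded
        return gm

    return wrapped
```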

py/torch_tensorrt/fx/test/converters/acc_op/test_adaptive_avgpool.py

Lines changed: 2 additions & 0 deletions
@@ -92,6 +92,8 @@ def forward(self, x):
             TestModule(), input_specs, expected_ops={acc_ops.adaptive_avg_pool3d}
         )
 
+    # Testing with shape(-1, -1, -1, -1) results into error: "AdaptiveAvgPool2d and AdaptiveAvgPool3d currently doesn't support dynamic shapes for last two dims."
+
 
 if __name__ == "__main__":
     run_tests()

py/torch_tensorrt/fx/test/converters/acc_op/test_any.py

Lines changed: 22 additions & 0 deletions
@@ -5,6 +5,8 @@
 from torch.testing._internal.common_utils import run_tests
 from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase
 
+# from torch_tensorrt.fx.tools.common_fx2trt import InputTensorSpec
+
 
 class TestAnyConverters(AccTestCase):
     @parameterized.expand(
@@ -64,6 +66,26 @@ def forward(self, x):
             test_implicit_batch_dim=False,
         )
 
+    # Testing with shape (-1, -1, -1, -1) results into error: torch.zeros(tuple([*input_t.shape])). Trying to create tensor with negative dimension -1: [-1, -1, -1, -1]
+    """
+    def test_ops_with_dynamic_shape_four_dimensions(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return torch.any(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 256, 256), (3, 3, 256, 256), (5, 5, 256, 256))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={acc_ops.any}
+        )
+    """
+
 
 if __name__ == "__main__":
     run_tests()

py/torch_tensorrt/fx/test/converters/acc_op/test_as_strided.py

Lines changed: 21 additions & 1 deletion
@@ -3,7 +3,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase
+from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec
 
 
 class TestConverter(AccTestCase):
@@ -30,6 +30,26 @@ def forward(self, x):
             test_implicit_batch_dim=False,
         )
 
+    # Testing with shape (-1, -1, -1, -1) results into error: RuntimeError: setStorage: sizes [2, 3], strides [1, 2], storage offset 0, and itemsize 8 requiring a storage size of 48 are out of bounds for storage of size 1
+    """
+    def test_as_strided_with_dynamic_shape_four_dimensions(self):
+        class Stride(nn.Module):
+            def forward(self, x):
+                return torch.as_strided(torch.tensor([5, 5]), (2, 3), (1, 2), 0)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            Stride(), input_specs, expected_ops={acc_ops.as_strided}
+        )
+    """
+
 
 if __name__ == "__main__":
     run_tests()
