
Commit d418a37

Shirong Wu authored and Wei Wei committed
Enable explicit batch dim support for getitem and chunk (#70)
Summary:
Pull Request resolved: https://github.com/pytorch/fx2trt/pull/70

Enable explicit batch dim support (with limited support for dynamic shape) for the acc_op getitem and chunk converters. Right now the converters can only process one dynamic shape dim.

Reviewed By: frank-wei, 842974287
Differential Revision: D34454742
fbshipit-source-id: f1bf643ca94b268be7193d332a5819e6bc8d876d
1 parent 6142f97 · commit d418a37

File tree: 4 files changed, +205 −14 lines

fx/converters/acc_ops_converters.py

Lines changed: 27 additions & 11 deletions
```diff
@@ -2752,9 +2752,15 @@ def acc_ops_getitem(
     if not isinstance(input_val, TRTTensor):
         return operator.getitem(input_val, slices)  # type: ignore[arg-type]
 
-    assert not has_dynamic_shape(
-        input_val.shape
-    ), "Currently we don't support slicing tensor if it has dynamic shape."
+    if not isinstance(slices, tuple) and not isinstance(slices, list):
+        slices = (slices,)
+
+    dynamic_shape = get_dynamic_dims(input_val.shape)
+    if dynamic_shape:
+        for i, s in zip(input_val.shape, slices):
+            assert i > 0 or (
+                s in [slice(None, None, None), slice(0, None, None), Ellipsis]
+            ), "We don't support slicing tensor on dynamic shape."
 
     def num_slice_types(slices):
         """
```
```diff
@@ -2776,9 +2782,6 @@ def slice_to_trt_params(py_slice, dim_size):
         size = math.ceil((stop - start) * 1.0 / stride)
         return start, size, stride
 
-    if not isinstance(slices, tuple) and not isinstance(slices, list):
-        slices = (slices,)
-
     if network.has_implicit_batch_dimension:
         # Raise an error if it's trying to subscript batch dimension unless it's
         # slice(None, None, None).
```
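For context, `slice_to_trt_params` (only its tail is visible in this hunk) normalizes a Python slice into TensorRT's (start, size, stride) triple, where size counts output elements. A sketch of the presumed full helper, using `slice.indices` to resolve `None` and negative bounds; the real helper's normalization may differ in detail:

```python
import math

def slice_to_trt_params(py_slice: slice, dim_size: int):
    """Resolve a Python slice against a known dim size into (start, size, stride)."""
    start, stop, stride = py_slice.indices(dim_size)
    # size = number of elements the slice produces, as in the hunk above.
    size = math.ceil((stop - start) * 1.0 / stride)
    return start, size, stride

print(slice_to_trt_params(slice(1, None, 2), 10))  # (1, 5, 2): elements 1, 3, 5, 7, 9
print(slice_to_trt_params(slice(-8, -2, 3), 10))   # (2, 2, 3): elements 2, 5
```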
```diff
@@ -2831,12 +2834,17 @@ def slice_to_trt_params(py_slice, dim_size):
         stride.append(1)
         i += 1
 
+    if dynamic_shape:
+        size = get_shape_with_dynamic_shape(network, size, input_val, target, name)
+
     layer = network.add_slice(
         input=input_val,
         start=start,
-        shape=size,
+        shape=[] if dynamic_shape else size,
         stride=stride,
     )
+    if dynamic_shape:
+        layer.set_input(2, size)
     set_layer_name(layer, target, name)
 
     # Add shuffle layer to insert dimensions for 'None' and remove dimensions for 'int'.
```
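This is the standard TensorRT pattern for runtime-computed extents: `ISliceLayer` takes its output shape either as the static `shape` argument or, when that is left empty, from optional layer input index 2. Distilled from the hunk above (`network`, `input_val`, `start`, `stride`, and the ITensor `size` are assumed from the surrounding converter):

```python
# When `size` is an ITensor computed from the network's own shape (via
# get_shape_with_dynamic_shape), the static argument is left empty and the
# tensor is wired into the slice layer instead.
layer = network.add_slice(input=input_val, start=start, shape=[], stride=stride)
layer.set_input(2, size)  # input 2 of ISliceLayer is the runtime size tensor
```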
```diff
@@ -3212,15 +3220,15 @@ def acc_ops_chunk(
             "of the TensorRT region!"
         )
 
+    dynamic_shape = has_dynamic_shape(input_val.shape)
     if network.has_implicit_batch_dimension:
         input_dim_size += 1
         dim = get_positive_dim(dim, input_dim_size)
         assert dim != 0, "Can't chunk on batch dim when it's implicit!"
         dim -= 1
     else:
-        assert not has_dynamic_shape(
-            input_val.shape
-        ), "We currently don't support dynamic shape for chunk."
+        if dynamic_shape:
+            assert input_val.shape[dim] != -1, "Can't chunk on dynamic shape dimension!"
         dim = get_positive_dim(dim, input_dim_size)
 
     if chunks > input_val.shape[dim]:
```
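The implicit-batch branch has to reconcile the user-facing dim (which counts the batch axis) with TensorRT's batch-less tensor rank. A small sketch of that bookkeeping, with `get_positive_dim` assumed to be a plain modulo normalization:

```python
def normalize_chunk_dim(dim: int, trt_rank: int, implicit_batch: bool) -> int:
    # Assumed behavior of get_positive_dim: map a possibly-negative dim into [0, size).
    def get_positive_dim(d: int, size: int) -> int:
        return d % size

    if implicit_batch:
        # The user-facing rank includes the batch axis, so normalize against
        # rank + 1, reject the batch axis itself, then shift left by one.
        dim = get_positive_dim(dim, trt_rank + 1)
        assert dim != 0, "Can't chunk on batch dim when it's implicit!"
        return dim - 1
    return get_positive_dim(dim, trt_rank)

print(normalize_chunk_dim(-2, 3, implicit_batch=True))   # -2 -> 2 -> 1
print(normalize_chunk_dim(-2, 3, implicit_batch=False))  # -2 -> 1
```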
```diff
@@ -3243,8 +3251,16 @@
     for i in range(chunks):
         shape = list(input_val.shape)
         shape[dim] = min(split_size, max_offset - offset)
+        if dynamic_shape:
+            shape = get_shape_with_dynamic_shape(
+                network, shape, input_val, target, f"{name}_{i}"
+            )
         start[dim] = offset
-        layer = network.add_slice(input_val, start=start, shape=shape, stride=stride)
+        layer = network.add_slice(
+            input_val, start=start, shape=[] if dynamic_shape else shape, stride=stride
+        )
+        if dynamic_shape:
+            layer.set_input(2, shape)
         offset += split_size
         set_layer_name(layer, target, f"{name}_{i}")
         output.append(layer.get_output(0))
```
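The loop mirrors torch.chunk semantics: split_size is the ceiling of dim_size / chunks, and the final chunk absorbs whatever remains. A pure-Python sketch of the extents it produces (`chunk_extents` is a hypothetical helper, for illustration only):

```python
import math

def chunk_extents(dim_size: int, chunks: int):
    chunks = min(chunks, dim_size)       # e.g. chunk(2000) on a size-10 dim -> 10 chunks
    split_size = math.ceil(dim_size / chunks)
    offset, out = 0, []
    while offset < dim_size:
        out.append((offset, min(split_size, dim_size - offset)))  # (start, length)
        offset += split_size
    return out

print(chunk_extents(10, 3))     # [(0, 4), (4, 4), (8, 2)], matching torch.chunk
print(chunk_extents(10, 2000))  # ten chunks of length 1
```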

fx/converters/converter_utils.py

Lines changed: 61 additions & 1 deletion
```diff
@@ -146,7 +146,10 @@ def has_dynamic_shape(shape: Shape) -> bool:
     Returns:
         A boolean value indicates whether there's dynamic dim in the shape.
     """
-    return any(s == -1 for s in shape)
+    count = 0
+    for s in shape:
+        count += 1 if s == -1 else 0
+    return count
 
 
 def get_axes_for_reduce_op(
```
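Note the helper now returns the count of -1 dims rather than a bool, so the `-> bool` annotation has become loose; existing truth-value call sites keep working because a zero count is falsy:

```python
def has_dynamic_shape(shape) -> bool:
    count = 0
    for s in shape:
        count += 1 if s == -1 else 0
    return count

assert has_dynamic_shape((-1, 10, -1)) == 2  # number of dynamic dims
assert not has_dynamic_shape((4, 10, 20))    # 0 is falsy, as before
```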
```diff
@@ -342,6 +345,63 @@ def broadcast(
     return a, b
 
 
+def get_shape_with_dynamic_shape(
+    network: TRTNetwork,
+    shape: Union[list, tuple, torch.Tensor],
+    input_val: TRTTensor,
+    target: Target,
+    name: str,
+) -> TRTTensor:
+    """
+    Prepare the real output tensor shape for dynamic shape mode tensor input.
+    How this function works:
+    Assuming the input_val has actual shape [2048, 256, 512] and the expected
+    reduce operation output shape is [-1, 128, 256], this function should
+    return [2048, 128, 256] as the actual reduce operation output shape. The
+    steps of the calculation are:
+    1. get the actual tensor shape of input_val via an add_shape layer;
+    2. create an all-zero tensor [0, 0, 0];
+    3. compare [-1, 128, 256] elementwise against [0, 0, 0] to get a
+       condition tensor [True, False, False];
+    4. use the condition tensor to select between [2048, 256, 512] and
+       [-1, 128, 256], replacing every -1 dynamic dimension with its actual
+       runtime value;
+    5. output the shape with the actual batch_size, i.e. [2048, 128, 256].
+
+    Args:
+        network (TRTNetwork): TensorRT network object.
+        shape: calculated shape of the expected output tensor.
+        input_val (TRTTensor): A TensorRT ITensor.
+        target (Target): Target of fx node.
+        name (str): The name we want to assign to the created TensorRT layer.
+    Returns:
+        A TensorRT ITensor that represents the given shape with every dynamic
+        (-1) dimension resolved to its runtime value.
+    """
+    # Get real shape info for input_val
+    input_shape = network.add_shape(input_val).get_output(0)
+
+    scale_layer = network.add_constant(
+        input_shape.shape, np.ascontiguousarray(shape, dtype=np.int32)
+    )
+    set_layer_name(scale_layer, target, f"{name}_scale")
+    scale_res = scale_layer.get_output(0)
+
+    length = input_shape.shape[0]
+    zero_layer = network.add_constant(
+        input_shape.shape, to_numpy(torch.zeros((length), dtype=torch.int32))
+    )
+    set_layer_name(zero_layer, target, f"{name}_zeros")
+
+    condition_val = add_binary_elementwise_layer(
+        network,
+        scale_res,
+        zero_layer.get_output(0),
+        trt.ElementWiseOperation.LESS,
+        target,
+        f"{name}_shape",
+    )
+    select_layer = network.add_select(condition_val, input_shape, scale_res)
+    set_layer_name(select_layer, target, f"{name}_select")
+    return select_layer.get_output(0)
+
+
 def add_binary_elementwise_layer(
     network: TRTNetwork,
     lhs_val: Union[int, float, TRTTensor, torch.Tensor],
```
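In effect, the LESS comparison plus select computes `where(expected < 0, actual, expected)` over the two shape tensors. A numpy sketch of the same arithmetic:

```python
import numpy as np

actual = np.array([2048, 256, 512], dtype=np.int32)  # from the add_shape layer at runtime
expected = np.array([-1, 128, 256], dtype=np.int32)  # requested shape with -1 placeholders
condition = expected < np.zeros_like(expected)       # [ True, False, False]
resolved = np.where(condition, actual, expected)     # TRT's add_select equivalent
print(resolved)                                      # [2048  128  256]
```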

test/converters/acc_op/test_chunk.py

Lines changed: 24 additions & 1 deletion
```diff
@@ -2,7 +2,7 @@
 import torch
 import torch.nn as nn
 from parameterized import parameterized
-from torch.testing._internal.common_fx2trt import AccTestCase
+from torch.testing._internal.common_fx2trt import InputTensorSpec, AccTestCase
 from torch.testing._internal.common_utils import run_tests
 
 
@@ -26,6 +26,29 @@ def forward(self, x):
             expected_ops={acc_ops.chunk},
         )
 
+    @parameterized.expand(
+        [
+            ("chunk", 3, 1),
+            ("chunk", 2000, 1),
+            ("chunk", 3, -2),
+        ]
+    )
+    def test_chunk_with_dynamic_shape(self, _, chunk, dim):
+        class Chunk(nn.Module):
+            def forward(self, x):
+                return x.chunk(chunk, dim)[0]
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, 10, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 10, 20), (5, 10, 20), (10, 10, 20))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Chunk(), input_specs, expected_ops={acc_ops.chunk}
+        )
+
 
 if __name__ == "__main__":
     run_tests()
```
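For readers unfamiliar with the harness: each `shape_ranges` entry in `InputTensorSpec` is, to my reading, a (min, opt, max) triple that becomes a TensorRT optimization profile, and only the -1 entries of `shape` actually vary across it:

```python
import torch
from torch.testing._internal.common_fx2trt import InputTensorSpec

spec = InputTensorSpec(
    shape=(-1, 10, -1),  # dims marked -1 are dynamic
    dtype=torch.float32,
    # One profile: the batch dim ranges over 1..10 (opt 5); the last dim is
    # pinned at 20 across min/opt/max, so it stays effectively static here.
    shape_ranges=[((1, 10, 20), (5, 10, 20), (10, 10, 20))],
)
```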

test/converters/acc_op/test_getitem.py

Lines changed: 93 additions & 1 deletion
```diff
@@ -2,7 +2,7 @@
 import torch
 import torch.nn as nn
 from parameterized import parameterized
-from torch.testing._internal.common_fx2trt import AccTestCase
+from torch.testing._internal.common_fx2trt import AccTestCase, InputTensorSpec
 from torch.testing._internal.common_utils import run_tests
 
 
@@ -52,6 +52,98 @@ def forward(self, x):
         inputs = [torch.randn(2, 10, 10, 10)]
         self.run_test(Getitem(idx), inputs, expected_ops={acc_ops.getitem})
 
+    @parameterized.expand(
+        [
+            ("slice_batch_dim", slice(None, None, None)),
+            ("ellipsis", (slice(None, None, None), ..., slice(0, -3, 2))),
+            (
+                "slice_all_none",
+                (slice(None, None, None), slice(None, None, None)),
+            ),
+            (
+                "slice_end_none",
+                (slice(None, None, None), slice(None, None, None), slice(1, None, 1)),
+            ),
+            (
+                "slice_step_none",
+                (slice(None, None, None), slice(None, None, None), slice(0, 3, None)),
+            ),
+            ("slice_neg_idx", (slice(None, None, None), -1, slice(None, None, None))),
+            (
+                "slice_neg_slice",
+                (slice(None, None, None), slice(None, None, None), slice(-8, -2, 3)),
+            ),
+            ("multi_dim", (slice(None, None, None), 0, 1)),
+            (
+                "slice_multi_dim",
+                (slice(None, None, None), slice(0, 3, 2), slice(1, -1, 3)),
+            ),
+            (
+                "none",
+                (slice(None, None, None), None, slice(1, -1, 3)),
+            ),
+        ]
+    )
+    def test_getitem_with_dynamic_shape(self, _, idx):
+        class Getitem(nn.Module):
+            def __init__(self, idx):
+                super().__init__()
+                self.idx = idx
+
+            def forward(self, x):
+                x = x + x
+                return x[self.idx]
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, 256, 256),
+                dtype=torch.float32,
+                shape_ranges=[((1, 256, 256), (3, 256, 256), (5, 256, 256))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Getitem(idx), input_specs, expected_ops={acc_ops.getitem}
+        )
+
+    @parameterized.expand(
+        [
+            ("slice_batch_dim", slice(None, None, None)),
+            ("ellipsis", (slice(None, None, None), ..., slice(0, -3, 2))),
+            (
+                "slice_all_none",
+                (slice(None, None, None), slice(None, None, None)),
+            ),
+            (
+                "slice_end_none",
+                (slice(None, None, None), slice(None, None, None), slice(1, None, 1)),
+            ),
+            (
+                "slice_step_none",
+                (slice(None, None, None), slice(None, None, None), slice(0, 3, None)),
+            ),
+        ]
+    )
+    def test_getitem_with_multi_dynamic_shape(self, _, idx):
+        class Getitem(nn.Module):
+            def __init__(self, idx):
+                super().__init__()
+                self.idx = idx
+
+            def forward(self, x):
+                x = x + x
+                return x[self.idx]
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, 256),
+                dtype=torch.float32,
+                shape_ranges=[((1, 128, 256), (3, 192, 256), (5, 256, 256))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Getitem(idx), input_specs, expected_ops={acc_ops.getitem}
+        )
+
 
 if __name__ == "__main__":
     run_tests()
```
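The parameterized `idx` values are simply the desugared forms of ordinary subscripts, which makes the case names easier to map back to slicing syntax:

```python
import torch

x = torch.randn(3, 256, 256)
# "slice_neg_idx": (slice(None, None, None), -1, slice(None, None, None)) is x[:, -1, :]
assert torch.equal(x[(slice(None, None, None), -1, slice(None, None, None))], x[:, -1, :])
# "none": (slice(None, None, None), None, slice(1, -1, 3)) is x[:, None, 1:-1:3]
assert torch.equal(x[(slice(None, None, None), None, slice(1, -1, 3))], x[:, None, 1:-1:3])
```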
