Commit f32c7a9

feat: dynamic shape support for pad ops (#3045)
1 parent 6321710 commit f32c7a9

File tree

3 files changed: +504 -83 lines changed


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 25 additions & 9 deletions
@@ -2943,7 +2943,9 @@ def aten_ops_addmm(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.constant_pad_nd.default)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.constant_pad_nd.default, supports_dynamic_shapes=True
+)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
@@ -2967,9 +2969,15 @@ def aten_ops_constant_pad(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.reflection_pad1d.default)
-@dynamo_tensorrt_converter(torch.ops.aten.reflection_pad2d.default)
-@dynamo_tensorrt_converter(torch.ops.aten.reflection_pad3d.default)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.reflection_pad1d.default, supports_dynamic_shapes=True
+)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.reflection_pad2d.default, supports_dynamic_shapes=True
+)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.reflection_pad3d.default, supports_dynamic_shapes=True
+)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
@@ -2992,9 +3000,15 @@ def aten_ops_reflection_pad(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.replication_pad1d.default)
-@dynamo_tensorrt_converter(torch.ops.aten.replication_pad2d.default)
-@dynamo_tensorrt_converter(torch.ops.aten.replication_pad3d.default)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.replication_pad1d.default, supports_dynamic_shapes=True
+)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.replication_pad2d.default, supports_dynamic_shapes=True
+)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.replication_pad3d.default, supports_dynamic_shapes=True
+)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
@@ -3017,7 +3031,9 @@ def aten_ops_replication_pad(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten._pad_circular.default)
+@dynamo_tensorrt_converter(
+    torch.ops.aten._pad_circular.default, supports_dynamic_shapes=True
+)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
@@ -3040,7 +3056,7 @@ def aten_ops_circular_pad(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.pad.default)
+@dynamo_tensorrt_converter(torch.ops.aten.pad.default, supports_dynamic_shapes=True)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
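
These registrations now pass supports_dynamic_shapes=True, so graphs containing pad ops can be compiled against a range of input shapes rather than a single fixed size. A minimal sketch of the user-facing flow; the module and the min/opt/max ranges below are illustrative assumptions, not taken from this commit:

import torch
import torch_tensorrt


class PadModule(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Zero-pads the last two dims by 1 on each side; this lowers to
        # aten.constant_pad_nd in the exported graph.
        return torch.nn.functional.pad(x, (1, 1, 1, 1), mode="constant", value=0.0)


model = PadModule().eval().cuda()

# Dynamic batch and spatial dims; the min/opt/max ranges here are illustrative.
inputs = [
    torch_tensorrt.Input(
        min_shape=(1, 3, 64, 64),
        opt_shape=(4, 3, 128, 128),
        max_shape=(8, 3, 256, 256),
        dtype=torch.float32,
    )
]

trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)

# Any input within the declared range runs through the same engine.
out = trt_model(torch.randn(2, 3, 96, 96, device="cuda"))

With dynamic shape support declared, pad ops in such a graph should be converted to TensorRT layers instead of being excluded from conversion when the inputs carry dynamic dimensions.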

py/torch_tensorrt/dynamo/conversion/impl/pad.py

Lines changed: 134 additions & 74 deletions
@@ -1,15 +1,17 @@
 from typing import Optional, Sequence, Union
 
+import numpy as np
 import tensorrt as trt
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
-from torch_tensorrt.fx.converters.converter_utils import (
-    has_dynamic_shape,
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    get_trt_tensor,
     set_layer_name,
 )
-from torch_tensorrt.fx.types import TRTTensor
+from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
+from torch_tensorrt.dynamo.types import TRTTensor
 
 """
 Note: IPaddingLayer is deprecated in TensorRT 8.2 and will be removed in TensorRT 10.0.
@@ -18,39 +20,109 @@
 """
 
 
-def constant_padNd(
+def get_padded_shape_tensors(
     ctx: ConversionContext,
     target: Union[Target, str],
     source_ir: Optional[SourceIR],
     name: str,
     input: TRTTensor,
     pad: Sequence[int],
-    value: Union[int, float] = 0,
 ) -> TRTTensor:
-    if has_dynamic_shape(input.shape):
-        assert input.shape[1] != -1, "Channel dim can't be dynamic for padding."
-
     rank = len(input.shape)
-
     if len(pad) // 2 > rank:
         raise RuntimeError(
-            f"Trying to pad last {len(pad) // 2} dimension but the input only has {rank} dimension."
+            f"Trying to pad last {len(pad) // 2} dimensions but the input only has {rank} dimensions."
         )
 
+    input_shape_tensor = get_shape_with_dynamic_shape(
+        ctx,
+        target,
+        source_ir,
+        name + "_input_shape",
+        input.shape,
+        input,
+    )
+    padded_shape_tensor = input_shape_tensor
+
     start_list = [0] * rank
-    new_shape = list(input.shape)
+    for i in range(len(pad) // 2):
+        dim_index = rank - (i + 1)
+        pad_before = pad[i * 2]
+        pad_after = pad[i * 2 + 1]
 
-    for i in range(0, len(pad) // 2):
-        start_list[-i - 1] = -pad[i * 2]
-        new_shape[-i - 1] += pad[i * 2] + pad[i * 2 + 1]
+        pad_sum = get_trt_tensor(
+            ctx, pad_before + pad_after, f"{name}_pad_sum_{i}", dtype=np.int32
+        )
+        dim_shape = ctx.net.add_slice(
+            input_shape_tensor,
+            start=(dim_index,),
+            shape=(1,),
+            stride=(1,),
+        ).get_output(0)
+
+        new_dim_shape = impl.elementwise.add(
+            ctx, target, source_ir, f"{name}_shape_dim_{i}", dim_shape, pad_sum
+        )
+        start_list[dim_index] = -pad_before
+
+        slices = []
+        for j in range(rank):
+            if j == dim_index:
+                slices.append(new_dim_shape)
+            else:
+                slices.append(
+                    ctx.net.add_slice(
+                        padded_shape_tensor,
+                        start=(j,),
+                        shape=(1,),
+                        stride=(1,),
+                    ).get_output(0)
+                )
+        padded_shape_tensor = impl.cat.cat(
+            ctx, target, source_ir, f"{name}_cat", slices, 0
+        )
+
+    start_indices_tensor = get_trt_tensor(
+        ctx,
+        np.array(start_list, dtype=np.int32),
+        f"{name}_start_indices_tensor",
+        dtype=np.int32,
+    )
+
+    return start_indices_tensor, padded_shape_tensor
+
+
+def constant_padNd(
+    ctx: ConversionContext,
+    target: Union[Target, str],
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    pad: Sequence[int],
+    value: Union[int, float] = 0,
+) -> TRTTensor:
+
+    rank = len(input.shape)
+
+    start_indices_tensor, padded_shape_tensor = get_padded_shape_tensors(
+        ctx, target, source_ir, name, input, pad
+    )
 
     stride_list = [1] * rank
+    stride_tensor = get_trt_tensor(
+        ctx,
+        np.array(stride_list, dtype=np.int32),
+        f"{name}_stride_tensor",
+        dtype=np.int32,
+    )
+
     layer = ctx.net.add_slice(
-        input,
-        start=tuple(start_list),
-        shape=tuple(new_shape),
-        stride=tuple(stride_list),
+        input, start=trt.Dims(), shape=trt.Dims(), stride=trt.Dims()
     )
+    layer.set_input(1, start_indices_tensor)
+    layer.set_input(2, padded_shape_tensor)
+    layer.set_input(3, stride_tensor)
+
     value_const = get_trt_tensor(ctx, value, f"{name}_value", input.dtype)
     layer.set_input(4, value_const)
     layer.mode = trt.SampleMode.FILL
@@ -67,30 +139,26 @@ def reflection_padNd(
     input: TRTTensor,
     padding: Sequence[int],
 ) -> TRTTensor:
-    if has_dynamic_shape(input.shape):
-        assert input.shape[1] != -1, "Channel dim can't be dynamic for padding."
-
     rank = len(input.shape)
 
-    if len(padding) // 2 > rank:
-        raise RuntimeError(
-            f"Trying to pad last {len(padding) // 2} dimension but the input only has {rank} dimension."
-        )
-
-    start_list = [0] * rank
-    new_shape = list(input.shape)
-
-    for i in range(0, len(padding) // 2):
-        start_list[-i - 1] = -padding[i * 2]
-        new_shape[-i - 1] += padding[i * 2] + padding[i * 2 + 1]
+    start_indices_tensor, padded_shape_tensor = get_padded_shape_tensors(
+        ctx, target, source_ir, name, input, padding
+    )
 
     stride_list = [1] * rank
+    stride_tensor = get_trt_tensor(
+        ctx,
+        np.array(stride_list, dtype=np.int32),
+        f"{name}_stride_tensor",
+        dtype=np.int32,
+    )
+
     layer = ctx.net.add_slice(
-        input,
-        start=tuple(start_list),
-        shape=tuple(new_shape),
-        stride=tuple(stride_list),
+        input, start=trt.Dims(), shape=trt.Dims(), stride=trt.Dims()
    )
+    layer.set_input(1, start_indices_tensor)
+    layer.set_input(2, padded_shape_tensor)
+    layer.set_input(3, stride_tensor)
     layer.mode = trt.SampleMode.REFLECT
 
     set_layer_name(layer, target, name, source_ir)
@@ -105,30 +173,26 @@ def replication_padNd(
     input: TRTTensor,
     padding: Sequence[int],
 ) -> TRTTensor:
-    if has_dynamic_shape(input.shape):
-        assert input.shape[1] != -1, "Channel dim can't be dynamic for padding."
-
     rank = len(input.shape)
 
-    if len(padding) // 2 > rank:
-        raise RuntimeError(
-            f"Trying to pad last {len(padding) // 2} dimension but the input only has {rank} dimension."
-        )
-
-    start_list = [0] * rank
-    new_shape = list(input.shape)
-
-    for i in range(0, len(padding) // 2):
-        start_list[-i - 1] = -padding[i * 2]
-        new_shape[-i - 1] += padding[i * 2] + padding[i * 2 + 1]
+    start_indices_tensor, padded_shape_tensor = get_padded_shape_tensors(
+        ctx, target, source_ir, name, input, padding
+    )
 
     stride_list = [1] * rank
+    stride_tensor = get_trt_tensor(
+        ctx,
+        np.array(stride_list, dtype=np.int32),
+        f"{name}_stride_tensor",
+        dtype=np.int32,
+    )
+
     layer = ctx.net.add_slice(
-        input,
-        start=tuple(start_list),
-        shape=tuple(new_shape),
-        stride=tuple(stride_list),
+        input, start=trt.Dims(), shape=trt.Dims(), stride=trt.Dims()
    )
+    layer.set_input(1, start_indices_tensor)
+    layer.set_input(2, padded_shape_tensor)
+    layer.set_input(3, stride_tensor)
     layer.mode = trt.SampleMode.CLAMP
 
     set_layer_name(layer, target, name, source_ir)
@@ -141,32 +205,28 @@ def circular_padNd(
     source_ir: Optional[SourceIR],
     name: str,
     input: TRTTensor,
-    pad: Sequence[int],
+    padding: Sequence[int],
 ) -> TRTTensor:
-    if has_dynamic_shape(input.shape):
-        assert input.shape[1] != -1, "Channel dim can't be dynamic for padding."
-
     rank = len(input.shape)
 
-    if len(pad) // 2 > rank:
-        raise RuntimeError(
-            f"Trying to pad last {len(pad) // 2} dimension but the input only has {rank} dimension."
-        )
-
-    start_list = [0] * rank
-    new_shape = list(input.shape)
-
-    for i in range(0, len(pad) // 2):
-        start_list[-i - 1] = -pad[i * 2]
-        new_shape[-i - 1] += pad[i * 2] + pad[i * 2 + 1]
+    start_indices_tensor, padded_shape_tensor = get_padded_shape_tensors(
+        ctx, target, source_ir, name, input, padding
+    )
 
     stride_list = [1] * rank
+    stride_tensor = get_trt_tensor(
+        ctx,
+        np.array(stride_list, dtype=np.int32),
+        f"{name}_stride_tensor",
+        dtype=np.int32,
+    )
+
     layer = ctx.net.add_slice(
-        input,
-        start=tuple(start_list),
-        shape=tuple(new_shape),
-        stride=tuple(stride_list),
+        input, start=trt.Dims(), shape=trt.Dims(), stride=trt.Dims()
    )
+    layer.set_input(1, start_indices_tensor)
+    layer.set_input(2, padded_shape_tensor)
+    layer.set_input(3, stride_tensor)
     layer.mode = trt.SampleMode.WRAP
 
     set_layer_name(layer, target, name, source_ir)
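
For reference, the values get_padded_shape_tensors now assembles as runtime shape tensors (via get_shape_with_dynamic_shape, per-dimension ISliceLayer reads, an elementwise add, and a concat) are the same start indices and output shape the old static-shape code computed directly in Python. A minimal standalone sketch of that arithmetic; padded_shape_and_start is a hypothetical helper for illustration only, not part of this commit:

from typing import List, Sequence, Tuple


def padded_shape_and_start(
    input_shape: Sequence[int], pad: Sequence[int]
) -> Tuple[List[int], List[int]]:
    # pad is ordered (last_dim_before, last_dim_after, next_dim_before, ...),
    # matching the aten padding convention consumed by the converters above.
    rank = len(input_shape)
    if len(pad) // 2 > rank:
        raise RuntimeError(
            f"Trying to pad last {len(pad) // 2} dimensions but the input only has {rank} dimensions."
        )
    start = [0] * rank
    shape = list(input_shape)
    for i in range(len(pad) // 2):
        dim = rank - (i + 1)
        pad_before, pad_after = pad[i * 2], pad[i * 2 + 1]
        start[dim] = -pad_before  # negative start lets the slice read "before" index 0
        shape[dim] += pad_before + pad_after  # output extent grows by the total padding
    return start, shape


# (N, C, H, W) = (2, 3, 4, 4) padded by (1, 1, 2, 0):
# start -> [0, 0, -2, -1], shape -> [2, 3, 6, 6]
print(padded_shape_and_start((2, 3, 4, 4), (1, 1, 2, 0)))

The out-of-range reads implied by the negative starts and enlarged shape are then resolved by the slice mode: FILL writes the constant value, while REFLECT, CLAMP, and WRAP produce reflection, replication, and circular padding respectively.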
