Skip to content

Commit ffa4f64

Browse files
authored
tile dynamic dim (#3085)
1 parent 39f8255 commit ffa4f64

File tree

3 files changed

+105
-5
lines changed

3 files changed

+105
-5
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -953,7 +953,7 @@ def aten_ops_cumsum(
953953
)
954954

955955

956-
@dynamo_tensorrt_converter(torch.ops.aten.tile.default)
956+
@dynamo_tensorrt_converter(torch.ops.aten.tile.default, supports_dynamic_shapes=True)
957957
@enforce_tensor_types(
958958
{
959959
0: (TRTTensor,),

py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,7 @@ def tile(
457457
dims: Sequence[int],
458458
) -> TRTTensor:
459459
diff = len(dims) - len(input.shape)
460+
has_dynamic_shape_input = has_dynamic_shape(input.shape)
460461
if diff > 0:
461462
# prepend 1 to input.shape
462463
new_shape = (1,) * diff + tuple(input.shape)
@@ -467,10 +468,64 @@ def tile(
467468
# prepend 1 to dims
468469
dims = (1,) * -diff + tuple(dims)
469470

470-
shapes = [i * j for i, j in zip(input.shape, dims)]
471-
starts = [0] * len(dims)
472-
strides = [1] * len(dims)
473-
layer = ctx.net.add_slice(input, tuple(starts), tuple(shapes), tuple(strides))
471+
starts = tuple([0] * len(dims))
472+
strides = tuple([1] * len(dims))
473+
# layer = ctx.net.add_slice(input, tuple(starts), tuple(shapes), tuple(strides))
474+
if not (has_dynamic_shape_input):
475+
shapes = [i * j for i, j in zip(input.shape, dims)]
476+
layer = ctx.net.add_slice(input, tuple(starts), tuple(shapes), tuple(strides))
477+
else:
478+
shapes = []
479+
index = 0
480+
for i, j in zip(input.shape, dims):
481+
if i == DYNAMIC_DIM:
482+
i = get_shape(
483+
ctx, target, source_ir, name + f"_input_{index}", input, index
484+
)
485+
prod_shape = convert_binary_elementwise(
486+
ctx,
487+
target,
488+
source_ir,
489+
name + "_prod",
490+
trt.ElementWiseOperation.PROD,
491+
i,
492+
j,
493+
)
494+
shapes.append(prod_shape)
495+
index = index + 1
496+
layer = ctx.net.add_slice(
497+
input, start=trt.Dims(), shape=trt.Dims(), stride=trt.Dims()
498+
)
499+
shape_tensor = cat(
500+
ctx,
501+
target,
502+
source_ir,
503+
name + "_shape_concat",
504+
tuple(shapes),
505+
0,
506+
cast_dtype=trt.int32,
507+
)
508+
start_tensor = cat(
509+
ctx,
510+
target,
511+
source_ir,
512+
name + "_start_concat",
513+
starts,
514+
0,
515+
cast_dtype=trt.int32,
516+
)
517+
stride_tensor = cat(
518+
ctx,
519+
target,
520+
source_ir,
521+
name + "_stride_concat",
522+
strides,
523+
0,
524+
cast_dtype=trt.int32,
525+
)
526+
layer.set_input(1, start_tensor)
527+
layer.set_input(2, shape_tensor)
528+
layer.set_input(3, stride_tensor)
474529
layer.mode = trt.SampleMode.WRAP
475530
set_layer_name(layer, target, name)
476531
return layer.get_output(0)

tests/py/dynamo/conversion/test_tile_aten.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import torch.nn as nn
33
from parameterized import parameterized
44
from torch.testing._internal.common_utils import run_tests
5+
from torch_tensorrt import Input
56

67
from .harness import DispatchTestCase
78

@@ -71,5 +72,49 @@ def forward(self, x):
7172
)
7273

7374

75+
class TestTileConverterDynamicShape(DispatchTestCase):
76+
@parameterized.expand(
77+
[
78+
((3,), (3,), (6,), (1,)),
79+
((3,), (3,), (6,), (0,)),
80+
((3,), (3,), (6,), (2,)),
81+
((2,), (3,), (6,), (2, 2)),
82+
((2,), (3,), (6,), (0, 2)),
83+
# 2d cases
84+
((3, 1), (3, 1), (6, 1), (0,)),
85+
((3, 1), (3, 1), (6, 1), (2,)),
86+
((2, 3), (2, 3), (4, 3), (2, 2)),
87+
((2, 3), (2, 3), (4, 3), (1, 0)),
88+
((2, 3), (2, 3), (4, 3), (0, 2)),
89+
((2, 3), (2, 3), (4, 3), (4, 2, 3)),
90+
((2, 3), (2, 3), (4, 3), (0, 0, 3)),
91+
((2, 3), (2, 3), (4, 3), (4, 2, 3, 1, 2)),
92+
# 3d cases
93+
((4, 2, 3), (4, 2, 3), (6, 2, 3), (2,)),
94+
((4, 2, 3), (4, 2, 3), (6, 2, 3), (1, 2)),
95+
((1, 2, 3), (1, 2, 3), (6, 2, 3), (2, 3)),
96+
((1, 2, 3), (1, 2, 3), (6, 2, 3), (2, 3, 4)),
97+
((1, 2, 3), (1, 2, 3), (6, 2, 3), (2, 3, 4, 5)),
98+
]
99+
)
100+
def test_tile_input_dynamic(self, min_shape, opt_shape, max_shape, dims):
101+
class Tile(nn.Module):
102+
def forward(self, x):
103+
return torch.ops.aten.tile.default(x, dims)
104+
105+
input_specs = [
106+
Input(
107+
min_shape=min_shape,
108+
opt_shape=opt_shape,
109+
max_shape=max_shape,
110+
dtype=torch.float32,
111+
),
112+
]
113+
self.run_test_with_dynamic_shape(
114+
Tile(),
115+
input_specs,
116+
)
117+
118+
74119
if __name__ == "__main__":
75120
run_tests()

0 commit comments

Comments (0)