
Commit 363d0bd

fix unsqueeze cannot work on more than 1 dynamic_shape dimensions (#2933)
1 parent c7af229 · commit 363d0bd

File tree

2 files changed: +68 −17 lines

py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py

Lines changed: 42 additions & 7 deletions
@@ -9,7 +9,6 @@
 )
 from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 from torch_tensorrt.fx.types import Shape, TRTTensor
-from torch_tensorrt.fx.utils import get_dynamic_dims


 def unsqueeze(
@@ -32,13 +31,49 @@ def unsqueeze(
     input_shape_size = len(input_val.shape)
     dim = get_positive_dim(dim, input_shape_size + 1)

-    assert (
-        len(get_dynamic_dims(input_val.shape)) <= 1
-    ), "Currently we don't support unsqueeze with more than one dynamic dims."
-    layer = ctx.net.add_shuffle(input_val)
-    layer.reshape_dims = (
-        tuple(input_val.shape)[:dim] + (1,) + tuple(input_val.shape)[dim:]
+    intermediate_dim = 0
+    dynamic_shape_cnt = 0
+    # if we unsqueeze the last dimension, we can directly append to the shape
+    if dim == input_shape_size:
+        intermediate_dim = dim
+    else:
+        # since at most one dimension is permitted to be specified as -1,
+        # find an intermediate_dim at or after which at most one dynamic dim remains,
+        # then add a transpose after the reshape if it is not the final shape we want
+        for i, s in reversed(list(enumerate(input_val.shape))):
+            if i >= dim:
+                if s == -1:
+                    dynamic_shape_cnt += 1
+                if dynamic_shape_cnt > 1:
+                    intermediate_dim = i + 1
+                    break
+            if i == dim:
+                intermediate_dim = i
+                break
+    # calculate the new_shape for the shuffle layer's reshape_dims
+    new_shape = list(
+        tuple(input_val.shape)[:intermediate_dim]
+        + (1,)
+        + tuple(input_val.shape)[intermediate_dim:]
     )
+    for i, s in enumerate(new_shape):
+        if i < intermediate_dim and s == -1:
+            new_shape[i] = 0
+    layer = ctx.net.add_shuffle(input_val)
+    layer.reshape_dims = tuple(new_shape)
+    # if intermediate_dim is not the final dim we want to unsqueeze, add a second_transpose after the reshape
+    if intermediate_dim != dim:
+        # calculate the second_transpose for the shuffle layer
+        permutation = [*range(0, len(new_shape))]
+        # for example: if the reshape_dims is (3, 3, 5, 1, 5) and the final shape we want is (3, 1, 3, 5, 5),
+        # then intermediate_dim=3 and dim=1, so we move axis intermediate_dim in front of [dim, intermediate_dim)
+        new_permutation = (
+            tuple(permutation[:dim])
+            + (intermediate_dim,)
+            + tuple(permutation[dim:intermediate_dim])
+            + tuple(permutation[intermediate_dim + 1 :])
+        )
+        layer.second_transpose = new_permutation
     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)
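In short: a TensorRT shuffle reshape may specify at most one dimension as -1, so when the input has several dynamic dimensions the converter inserts the size-1 axis at an intermediate_dim past which only one -1 survives (earlier dynamic dims use the 0 "copy input dim" placeholder), then moves the axis to the requested dim via second_transpose. The following standalone sketch mirrors that shape arithmetic in plain PyTorch so it can be checked eagerly; plan_unsqueeze is an illustrative helper, not part of torch_tensorrt, and eager reshape/permute stand in for the shuffle layer:

# Illustrative sketch (hypothetical helper, not part of torch_tensorrt):
# -1 marks a dynamic dimension, 0 means "copy the input dimension",
# following TensorRT shuffle-layer semantics.
import torch


def plan_unsqueeze(shape, dim):
    rank = len(shape)
    dim = dim if dim >= 0 else dim + rank + 1  # normalize, as get_positive_dim does
    intermediate_dim = 0
    dynamic_cnt = 0
    if dim == rank:  # appending at the end never needs a transpose
        intermediate_dim = dim
    else:
        # walk from the back; stop where at most one -1 remains at or after i
        for i in reversed(range(rank)):
            if i >= dim:
                if shape[i] == -1:
                    dynamic_cnt += 1
                if dynamic_cnt > 1:
                    intermediate_dim = i + 1
                    break
            if i == dim:
                intermediate_dim = i
                break
    new_shape = list(shape[:intermediate_dim]) + [1] + list(shape[intermediate_dim:])
    # dims left of the insertion point keep their input size via the 0 placeholder
    new_shape = [0 if (i < intermediate_dim and s == -1) else s for i, s in enumerate(new_shape)]
    perm = None
    if intermediate_dim != dim:
        p = list(range(len(new_shape)))
        perm = p[:dim] + [intermediate_dim] + p[dim:intermediate_dim] + p[intermediate_dim + 1 :]
    return new_shape, perm


# Worked example from the code comment: (3, 3, -1, -1) unsqueezed at dim=1.
reshape_dims, perm = plan_unsqueeze((3, 3, -1, -1), dim=1)
assert reshape_dims == [3, 3, 0, 1, -1] and perm == [0, 3, 1, 2, 4]

# Eager check with a concrete (3, 3, 5, 5) tensor: resolve the 0/-1 placeholders
# the way TensorRT would, then reshape + permute must equal torch.unsqueeze.
x = torch.randn(3, 3, 5, 5)
resolved = [x.shape[i] if s == 0 else s for i, s in enumerate(reshape_dims)]
y = x.reshape(resolved).permute(perm)
assert torch.equal(y, x.unsqueeze(1))

Note that the transpose only moves the freshly inserted size-1 axis, so no data is reordered and the shuffle remains a metadata-level operation.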

tests/py/dynamo/conversion/test_unsqueeze_aten.py

Lines changed: 26 additions & 10 deletions
@@ -3,9 +3,8 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.dynamo.conversion import UnsupportedOperatorException
-
 from torch_tensorrt import Input
+from torch_tensorrt.dynamo.conversion import UnsupportedOperatorException

 from .harness import DispatchTestCase

@@ -29,16 +28,32 @@ def forward(self, x):
         inputs = [torch.randn(1, 2, 3)]
         self.run_test(Unsqueeze(dim), inputs)

-    # Testing with more than one dynamic dims results in following error:
-    # AssertionError: Currently we don't support unsqueeze with more than one dynamic dims.
-
     @parameterized.expand(
         [
-            ("negative_dim_dynamic", -4),
-            ("positive_dim_dynamic", 1),
+            ("1_dynamic_shape_2d_-3", -3, (2, 5), (3, 5), (4, 5)),
+            ("1_dynamic_shape_2d_-2", -2, (2, 3), (2, 4), (2, 5)),
+            ("1_dynamic_shape_2d_-1", -1, (2, 3), (2, 4), (2, 5)),
+            ("1_dynamic_shape_2d_0", 0, (2, 3), (2, 4), (2, 5)),
+            ("1_dynamic_shape_2d_1", 1, (2, 3), (2, 4), (2, 5)),
+            ("1_dynamic_shape_2d_2", 2, (2, 3), (2, 4), (2, 5)),
+            ("2_dynamic_shape_3d_-1", -1, (2, 2, 3), (4, 3, 3), (5, 5, 3)),
+            ("2_dynamic_shape_3d_0", 2, (2, 2, 3), (4, 3, 3), (5, 5, 3)),
+            ("2_dynamic_shape_3d_1", 1, (2, 2, 3), (4, 3, 3), (5, 6, 3)),
+            ("2_dynamic_shape_3d_2", 2, (2, 2, 3), (4, 3, 3), (6, 5, 3)),
+            ("4_dynamic_shape_4d_-4", -4, (1, 2, 3, 4), (2, 2, 3, 5), (3, 3, 5, 5)),
+            ("4_dynamic_shape_4d_-3", -3, (1, 2, 3, 4), (2, 2, 3, 5), (3, 3, 5, 5)),
+            ("4_dynamic_shape_4d_-2", -2, (1, 2, 3, 4), (2, 2, 3, 5), (4, 3, 5, 6)),
+            ("4_dynamic_shape_4d_-1", -1, (1, 2, 3, 4), (2, 2, 3, 5), (4, 3, 5, 6)),
+            ("4_dynamic_shape_4d_0", 0, (1, 2, 3, 4), (2, 2, 5, 7), (2, 3, 6, 8)),
+            ("4_dynamic_shape_4d_1", 1, (1, 2, 3, 4), (2, 2, 3, 5), (3, 3, 5, 5)),
+            ("4_dynamic_shape_4d_2", 2, (1, 2, 3, 4), (2, 2, 3, 5), (3, 3, 5, 5)),
+            ("4_dynamic_shape_4d_3", 3, (1, 2, 3, 4), (2, 2, 3, 5), (3, 3, 5, 5)),
+            ("4_dynamic_shape_4d_4", 4, (1, 2, 3, 4), (2, 2, 3, 5), (3, 3, 5, 5)),
         ]
     )
-    def test_unsqueeze_with_dynamic_shape(self, _, dim):
+    def test_unsqueeze_with_dynamic_shape(
+        self, _, dim, min_shape, opt_shape, max_shape
+    ):
         class Unsqueeze(nn.Module):
             def __init__(self, dim):
                 super().__init__()
@@ -49,9 +64,10 @@ def forward(self, x):

         input_specs = [
             Input(
-                shape=(-1, 2, 3),
                 dtype=torch.float32,
-                shape_ranges=[((1, 2, 3), (2, 2, 3), (3, 2, 3))],
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
             ),
         ]
         self.run_test_with_dynamic_shape(Unsqueeze(dim), input_specs)
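As a quick eager-mode illustration of what these parameterized cases assert (a sketch only; the real verification builds a TensorRT engine through run_test_with_dynamic_shape), torch.unsqueeze must insert a size-1 axis at the normalized dim for every shape in a min/opt/max triple:

import torch

# A few of the triples above, rechecked eagerly (illustrative, not the harness).
cases = [
    (-3, [(2, 5), (3, 5), (4, 5)]),
    (1, [(2, 2, 3), (4, 3, 3), (5, 6, 3)]),
    (4, [(1, 2, 3, 4), (2, 2, 3, 5), (3, 3, 5, 5)]),
]
for dim, shapes in cases:
    for shape in shapes:
        x = torch.randn(shape)
        d = dim if dim >= 0 else dim + x.dim() + 1  # normalize, as get_positive_dim does
        assert torch.unsqueeze(x, dim).shape == x.shape[:d] + (1,) + x.shape[d:]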
