
Commit 201b377

add dynamic shape support for amax/amin/max/min/prod/sum (#2943)
The CI pipeline failure is not related to this change; merging it anyway since we need to do the branch cut for release/2.4.
1 parent 6aa439b · commit 201b377
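For context, here is a minimal sketch of what the change enables end to end: compiling a module that calls one of these reductions over a dynamic input range via the dynamo frontend. The shape triple mirrors the tests below; it assumes a CUDA-enabled torch_tensorrt 2.4 install and is an illustration, not part of this commit.

import torch
import torch_tensorrt

class AmaxModule(torch.nn.Module):
    def forward(self, x):
        # One of the six reductions this commit marks as dynamic-shape capable.
        return torch.amax(x, dim=1, keepdim=True)

model = AmaxModule().eval().cuda()

# One optimization profile: the engine is built at opt_shape and accepts
# any input shape between min_shape and max_shape.
inputs = [
    torch_tensorrt.Input(
        min_shape=(2, 3, 5),
        opt_shape=(3, 4, 6),
        max_shape=(4, 5, 7),
        dtype=torch.float32,
    )
]

trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)

# Two different shapes inside the profile run through the same engine.
print(trt_model(torch.randn(2, 3, 5).cuda()).shape)
print(trt_model(torch.randn(4, 5, 7).cuda()).shape)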

File tree

7 files changed: +263 −11 lines changed


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 21 additions & 11 deletions

@@ -1146,7 +1146,7 @@ def aten_ops_expand(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.amax.default)
+@dynamo_tensorrt_converter(torch.ops.aten.amax.default, supports_dynamic_shapes=True)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
@@ -1170,7 +1170,7 @@ def aten_ops_amax(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.amin.default)
+@dynamo_tensorrt_converter(torch.ops.aten.amin.default, supports_dynamic_shapes=True)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
@@ -1194,9 +1194,9 @@ def aten_ops_amin(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.sum.default)
-@dynamo_tensorrt_converter(torch.ops.aten.sum.dim_IntList)
-@dynamo_tensorrt_converter(torch.ops.prims.sum.default)
+@dynamo_tensorrt_converter(torch.ops.aten.sum.default, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.aten.sum.dim_IntList, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.prims.sum.default, supports_dynamic_shapes=True)
 def aten_ops_sum(
     ctx: ConversionContext,
     target: Target,
@@ -1228,8 +1228,8 @@ def aten_ops_sum(
     return sum_


-@dynamo_tensorrt_converter(torch.ops.aten.prod.default)
-@dynamo_tensorrt_converter(torch.ops.aten.prod.dim_int)
+@dynamo_tensorrt_converter(torch.ops.aten.prod.default, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.aten.prod.dim_int, supports_dynamic_shapes=True)
 def aten_ops_prod(
     ctx: ConversionContext,
     target: Target,
@@ -1248,9 +1248,14 @@ def aten_ops_prod(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.max.default)
 @dynamo_tensorrt_converter(
-    torch.ops.aten.max.dim, capability_validator=one_user_validator
+    torch.ops.aten.max.default,
+    supports_dynamic_shapes=True,
+)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.max.dim,
+    capability_validator=one_user_validator,
+    supports_dynamic_shapes=True,
 )
 def aten_ops_max(
     ctx: ConversionContext,
@@ -1271,9 +1276,14 @@ def aten_ops_max(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.min.default)
 @dynamo_tensorrt_converter(
-    torch.ops.aten.min.dim, capability_validator=one_user_validator
+    torch.ops.aten.min.default,
+    supports_dynamic_shapes=True,
+)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.min.dim,
+    capability_validator=one_user_validator,
+    supports_dynamic_shapes=True,
 )
 def aten_ops_min(
     ctx: ConversionContext,
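The new keyword is the entire converter-side change: registration stays per op overload, and dynamic-shape capability becomes an explicit opt-in. As I read the registry, the effect is that a node carrying symbolic dimensions is only handed to a converter that has declared support; otherwise the op is partitioned out and falls back to eager PyTorch. A hypothetical one-function sketch of that gating (not the actual registry code):

def converter_eligible(node_is_dynamic: bool, supports_dynamic_shapes: bool) -> bool:
    # Static-shape nodes can always use a registered converter; nodes with
    # symbolic dimensions require the converter to have opted in.
    return (not node_is_dynamic) or supports_dynamic_shapes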

tests/py/dynamo/conversion/test_amax_aten.py

Lines changed: 33 additions & 0 deletions

@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input

 from .harness import DispatchTestCase

@@ -90,6 +91,38 @@ def forward(self, x):
             check_dtype=False,
         )

+    @parameterized.expand(
+        [
+            ((0, 1), True, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
+            ((0,), True, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
+            (1, True, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
+            (2, True, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
+            (-1, True, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
+            ((-1, 0), True, (2, 2, 5), (3, 3, 6), (4, 5, 7)),
+        ]
+    )
+    def test_amax_dynamic_shape(self, dim, keep_dim, min_shape, opt_shape, max_shape):
+        class Amax(nn.Module):
+            def __init__(self, dim):
+                super().__init__()
+                self.dim = dim
+
+            def forward(self, x):
+                return torch.ops.aten.amax.default(x, dim, keep_dim)
+
+        input_specs = [
+            Input(
+                dtype=torch.float32,
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Amax(dim),
+            input_specs,
+        )
+

 if __name__ == "__main__":
     run_tests()
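All of the new tests pivot on the same (min_shape, opt_shape, max_shape) triple, which maps onto a single TensorRT optimization profile: the engine is tuned at opt_shape and must accept any shape between min_shape and max_shape, so every row has to satisfy min <= opt <= max in each dimension. A quick standalone check of that invariant (plain Python, not part of the commit):

def valid_profile(min_shape, opt_shape, max_shape):
    # TensorRT optimization profiles require min <= opt <= max per dimension.
    return all(lo <= o <= hi for lo, o, hi in zip(min_shape, opt_shape, max_shape))

assert valid_profile((2, 3, 5), (3, 4, 6), (4, 5, 7))
assert valid_profile((2, 2, 5), (3, 3, 6), (4, 5, 7))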

tests/py/dynamo/conversion/test_amin_aten.py

Lines changed: 33 additions & 0 deletions

@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input

 from .harness import DispatchTestCase

@@ -90,6 +91,38 @@ def forward(self, x):
             check_dtype=False,
         )

+    @parameterized.expand(
+        [
+            ((0, 1), True, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
+            ((0,), False, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
+            (1, True, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
+            (2, False, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
+            (-1, True, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
+            ((-1, 0), True, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
+        ]
+    )
+    def test_amin_dynamic_shape(self, dim, keep_dim, min_shape, opt_shape, max_shape):
+        class Amin(nn.Module):
+            def __init__(self, dim):
+                super().__init__()
+                self.dim = dim
+
+            def forward(self, x):
+                return torch.ops.aten.amin.default(x, dim, keep_dim)
+
+        input_specs = [
+            Input(
+                dtype=torch.float32,
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Amin(dim),
+            input_specs,
+        )
+

 if __name__ == "__main__":
     run_tests()

tests/py/dynamo/conversion/test_max_aten.py

Lines changed: 57 additions & 0 deletions

@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input

 from .harness import DispatchTestCase

@@ -65,6 +66,62 @@ def forward(self, x):
             check_dtype=False,
         )

+    @parameterized.expand(
+        [
+            (1, True, (2, 2, 3), (2, 3, 3), (3, 3, 4)),
+            (2, False, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
+            (-1, True, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
+        ]
+    )
+    def test_max_dim_dynamic_shape(
+        self, dim, keep_dim, min_shape, opt_shape, max_shape
+    ):
+        class Max(nn.Module):
+            def __init__(self, dim):
+                super().__init__()
+                self.dim = dim
+
+            def forward(self, x):
+                return torch.ops.aten.max.dim(x, dim, keep_dim)[0]
+
+        input_specs = [
+            Input(
+                dtype=torch.float32,
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Max(dim),
+            input_specs,
+        )
+
+    @parameterized.expand(
+        [
+            ((2, 2, 3), (2, 3, 3), (3, 3, 4)),
+            ((2, 3, 5), (3, 4, 6), (4, 5, 7)),
+            ((2, 3, 5), (3, 4, 6), (4, 5, 7)),
+        ]
+    )
+    def test_max_default_dynamic_shape(self, min_shape, opt_shape, max_shape):
+        class Max(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.max.default(x)
+
+        input_specs = [
+            Input(
+                dtype=torch.float32,
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Max(),
+            input_specs,
+        )
+

 if __name__ == "__main__":
     run_tests()
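One detail worth flagging in the max.dim tests above: torch.ops.aten.max.dim returns a (values, indices) pair, and the converter is registered with capability_validator=one_user_validator, which by its name restricts conversion to nodes where a single output is consumed. Indexing with [0] keeps only the values, which keeps the test inside that condition. In eager mode:

import torch

# aten.max.dim returns both the reduced values and their indices.
values, indices = torch.ops.aten.max.dim(torch.randn(2, 3), 1, True)
print(values.shape, indices.shape)  # torch.Size([2, 1]) torch.Size([2, 1])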

tests/py/dynamo/conversion/test_min_aten.py

Lines changed: 57 additions & 0 deletions

@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input

 from .harness import DispatchTestCase

@@ -65,6 +66,62 @@ def forward(self, x):
             check_dtype=False,
         )

+    @parameterized.expand(
+        [
+            (1, True, (2, 2, 3), (2, 3, 3), (3, 3, 4)),
+            (2, False, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
+            (-1, True, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
+        ]
+    )
+    def test_min_dim_dynamic_shape(
+        self, dim, keep_dim, min_shape, opt_shape, max_shape
+    ):
+        class Min(nn.Module):
+            def __init__(self, dim):
+                super().__init__()
+                self.dim = dim
+
+            def forward(self, x):
+                return torch.ops.aten.min.dim(x, dim, keep_dim)[0]
+
+        input_specs = [
+            Input(
+                dtype=torch.float32,
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Min(dim),
+            input_specs,
+        )
+
+    @parameterized.expand(
+        [
+            ((2, 2, 3), (2, 3, 3), (3, 3, 4)),
+            ((2, 3, 5), (3, 4, 6), (4, 5, 7)),
+            ((2, 3, 5), (3, 4, 6), (4, 5, 7)),
+        ]
+    )
+    def test_min_default_dynamic_shape(self, min_shape, opt_shape, max_shape):
+        class Min(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.min.default(x)
+
+        input_specs = [
+            Input(
+                dtype=torch.float32,
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Min(),
+            input_specs,
+        )
+

 if __name__ == "__main__":
     run_tests()

tests/py/dynamo/conversion/test_prod_aten.py

Lines changed: 28 additions & 0 deletions

@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input

 from .harness import DispatchTestCase

@@ -68,6 +69,33 @@ def forward(self, x):
             use_dynamo_tracer=True,
         )

+    @parameterized.expand(
+        [
+            (0, (2, 3), (2, 4), (3, 5)),
+            (1, (2, 3), (2, 4), (3, 5)),
+            (2, (2, 2, 4), (2, 3, 4), (3, 4, 5)),
+            (-1, (2, 2, 4), (2, 3, 4), (3, 4, 5)),
+        ]
+    )
+    def test_prod_dynamic_shape(self, dim, min_shape, opt_shape, max_shape):
+        class Prod(nn.Module):
+            def forward(self, x):
+                return torch.prod(x, dim)
+
+        input_specs = [
+            Input(
+                dtype=torch.float32,
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Prod(),
+            input_specs,
+            use_dynamo_tracer=True,
+        )
+

 if __name__ == "__main__":
     run_tests()
