Skip to content

Commit e50550d

Browse files
Add dynamic shape support for leaky_relu/elu/hard_sigmoid/softplus (#2927)
1 parent e30e5ac commit e50550d

File tree

5 files changed

+34
-20
lines changed

5 files changed

+34
-20
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,9 @@ def aten_ops_tanh(
459459
)
460460

461461

462-
@dynamo_tensorrt_converter(torch.ops.aten.leaky_relu.default)
462+
@dynamo_tensorrt_converter(
463+
torch.ops.aten.leaky_relu.default, supports_dynamic_shapes=True
464+
)
463465
def aten_ops_leaky_relu(
464466
ctx: ConversionContext,
465467
target: Target,
@@ -477,7 +479,7 @@ def aten_ops_leaky_relu(
477479
)
478480

479481

480-
@dynamo_tensorrt_converter(torch.ops.aten.elu.default)
482+
@dynamo_tensorrt_converter(torch.ops.aten.elu.default, supports_dynamic_shapes=True)
481483
def aten_ops_elu(
482484
ctx: ConversionContext,
483485
target: Target,
@@ -496,7 +498,9 @@ def aten_ops_elu(
496498
)
497499

498500

499-
@dynamo_tensorrt_converter(torch.ops.aten.softplus.default)
501+
@dynamo_tensorrt_converter(
502+
torch.ops.aten.softplus.default, supports_dynamic_shapes=True
503+
)
500504
def aten_ops_softplus(
501505
ctx: ConversionContext,
502506
target: Target,
@@ -514,7 +518,9 @@ def aten_ops_softplus(
514518
)
515519

516520

517-
@dynamo_tensorrt_converter(torch.ops.aten.hardsigmoid.default)
521+
@dynamo_tensorrt_converter(
522+
torch.ops.aten.hardsigmoid.default, supports_dynamic_shapes=True
523+
)
518524
def aten_ops_hard_sigmoid(
519525
ctx: ConversionContext,
520526
target: Target,

tests/py/dynamo/conversion/test_elu_aten.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@ def forward(self, x):
2222

2323
input_specs = [
2424
Input(
25-
shape=(-1, -1, -1),
2625
dtype=torch.float32,
27-
shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
26+
min_shape=(1, 1, 1),
27+
opt_shape=(1, 2, 3),
28+
max_shape=(3, 3, 3),
2829
),
2930
]
3031
self.run_test_with_dynamic_shape(TestModule(), input_specs)
@@ -36,9 +37,10 @@ def forward(self, x):
3637

3738
input_specs = [
3839
Input(
39-
shape=(-1, -1, -1, -1),
4040
dtype=torch.float32,
41-
shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))],
41+
min_shape=(1, 1, 1, 5),
42+
opt_shape=(1, 2, 3, 5),
43+
max_shape=(3, 3, 3, 5),
4244
),
4345
]
4446

tests/py/dynamo/conversion/test_hard_sigmoid_aten.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@ def forward(self, x):
2222

2323
input_specs = [
2424
Input(
25-
shape=(-1, -1, -1),
2625
dtype=torch.float32,
27-
shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
26+
min_shape=(1, 1, 1),
27+
opt_shape=(1, 2, 3),
28+
max_shape=(3, 3, 3),
2829
),
2930
]
3031
self.run_test_with_dynamic_shape(TestModule(), input_specs)
@@ -36,9 +37,10 @@ def forward(self, x):
3637

3738
input_specs = [
3839
Input(
39-
shape=(-1, -1, -1, -1),
4040
dtype=torch.float32,
41-
shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))],
41+
min_shape=(1, 1, 1, 5),
42+
opt_shape=(1, 2, 3, 5),
43+
max_shape=(3, 3, 3, 5),
4244
),
4345
]
4446

tests/py/dynamo/conversion/test_leaky_relu_aten.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@ def forward(self, x):
2222

2323
input_specs = [
2424
Input(
25-
shape=(-1, -1, -1),
2625
dtype=torch.float32,
27-
shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
26+
min_shape=(1, 1, 1),
27+
opt_shape=(1, 2, 3),
28+
max_shape=(3, 3, 3),
2829
),
2930
]
3031
self.run_test_with_dynamic_shape(TestModule(), input_specs)
@@ -36,9 +37,10 @@ def forward(self, x):
3637

3738
input_specs = [
3839
Input(
39-
shape=(-1, -1, -1, -1),
4040
dtype=torch.float32,
41-
shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))],
41+
min_shape=(1, 1, 1, 5),
42+
opt_shape=(1, 2, 3, 5),
43+
max_shape=(3, 3, 3, 5),
4244
),
4345
]
4446

tests/py/dynamo/conversion/test_softplus_aten.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@ def forward(self, x):
2222

2323
input_specs = [
2424
Input(
25-
shape=(-1, -1, -1),
2625
dtype=torch.float32,
27-
shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
26+
min_shape=(1, 1, 1),
27+
opt_shape=(1, 2, 3),
28+
max_shape=(3, 3, 3),
2829
),
2930
]
3031
self.run_test_with_dynamic_shape(TestModule(), input_specs)
@@ -36,9 +37,10 @@ def forward(self, x):
3637

3738
input_specs = [
3839
Input(
39-
shape=(-1, -1, -1, -1),
4040
dtype=torch.float32,
41-
shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))],
41+
min_shape=(1, 1, 1, 5),
42+
opt_shape=(1, 2, 3, 5),
43+
max_shape=(3, 3, 3, 5),
4244
),
4345
]
4446

0 commit comments

Comments
 (0)