Skip to content

Commit 8e06793

Browse files
fix conv1d/deconv1d bug with stride more than 1 (#3737)
1 parent a93266a commit 8e06793

File tree

3 files changed

+11
-10
lines changed

3 files changed

+11
-10
lines changed

py/torch_tensorrt/dynamo/conversion/impl/conv.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from torch_tensorrt.dynamo.conversion.converter_utils import (
1212
SourceIR,
1313
cast_trt_tensor,
14-
extend_attr_to_tuple,
1514
get_trt_tensor,
1615
has_dynamic_shape,
1716
set_layer_name,
@@ -159,10 +158,9 @@ def convNd(
159158
# Expand parameters manually for Conv1D computations
160159
if is_conv1d:
161160
padding = (tuple(padding) + (0,)) if padding is not None else padding
162-
stride = extend_attr_to_tuple(stride, 2) if stride is not None else stride
163-
dilation = (
164-
extend_attr_to_tuple(dilation, 2) if dilation is not None else dilation
165-
)
161+
# stride in conv1d is (2,) -> need to change to (2, 1) in conv2d
162+
stride = (stride[0], 1) if stride is not None else stride
163+
dilation = (dilation[0], 1) if dilation is not None else dilation
166164

167165
# Set relevant attributes of convolution layer
168166
if padding is not None:

py/torch_tensorrt/dynamo/conversion/impl/deconv.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
1111
from torch_tensorrt.dynamo.conversion.converter_utils import (
1212
SourceIR,
13-
extend_attr_to_tuple,
1413
get_trt_tensor,
1514
has_dynamic_shape,
1615
to_torch,
@@ -142,10 +141,9 @@ def deconvNd(
142141
# Expand parameters manually for Conv1D computations
143142
if is_deconv1d:
144143
padding = (tuple(padding) + (0,)) if padding is not None else padding
145-
stride = extend_attr_to_tuple(stride, 2) if stride is not None else stride
146-
dilation = (
147-
extend_attr_to_tuple(dilation, 2) if dilation is not None else dilation
148-
)
144+
# stride in deconv1d is (2,) -> need to change to (2, 1) in deconv2d
145+
stride = (stride[0], 1) if stride is not None else stride
146+
dilation = (dilation[0], 1) if dilation is not None else dilation
149147
output_padding = (
150148
(tuple(output_padding) + (0,))
151149
if output_padding is not None

tests/py/dynamo/conversion/test_convolution_aten.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ class TestConvolutionConverter(DispatchTestCase):
1515
param("non_zero_padding", 1, padding=1),
1616
param("dilation", 1, dilation=2),
1717
param("groups", 1, groups=3),
18+
param("stride", 1, stride=1),
19+
param("stride_2", 1, stride=2),
20+
param("stride_tuple", 1, stride=(2,)),
1821
]
1922
)
2023
def test_conv1d(
@@ -52,6 +55,7 @@ def forward(self, x):
5255
("tuple_parameters", 1, (1), (1)),
5356
param("non_zero_padding", 1, padding=1),
5457
param("dilation", 1, dilation=2),
58+
param("stride", 1, stride=2),
5559
]
5660
)
5761
def test_conv1d_TRTTensor_weight(
@@ -140,6 +144,7 @@ def forward(self, x):
140144
param("tuple_dilation", 2, dilation=(3, 3)),
141145
param("list_dilation", 2, dilation=[3]),
142146
param("groups", 1, groups=3),
147+
param("stride", 1, stride=(2, 2)),
143148
]
144149
)
145150
def test_conv2d(

0 commit comments

Comments
 (0)