
Commit e46c061

[Relax][Onnx] Pass output_padding param in ConvTranspose (#18635)
### Summary

TVM and ONNX View report inconsistent result shapes for the ConvTranspose operator (a minimal repro sketch follows after this message).

### Steps to Reproduce

- ONNX View: output shape = (1, 6, 56, 56)

  <img width="500" height="350" alt="Screenshots" src="https://github.com/user-attachments/assets/8a7129b4-9ebc-47a0-aa47-6060e8df2827" />

- TVM: output shape = (1, 6, 55, 55), because `output_padding=[0, 0]` is used:

```
class Module:
    def main(input: R.Tensor((1, 3, 28, 28), dtype="float32"), weight: R.Tensor((3, 6, 3, 3), dtype="float32"), bias: R.Tensor((6,), dtype="float32")) -> R.Tensor((1, 6, 55, 55), dtype="float32"):
        R.func_attr({"num_input": 1, "params": [metadata["ffi.Tensor"][0], metadata["ffi.Tensor"][1]]})
        with R.dataflow():
            lv: R.Tensor((1, 6, 55, 55), dtype="float32") = R.nn.conv2d_transpose(input, weight, strides=[2, 2], padding=[1, 1, 1, 1], output_padding=[0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="IOHW", out_layout="NCHW", out_dtype="void")
            lv1: R.Tensor((1, 6, 1, 1), dtype="float32") = R.reshape(bias, R.shape([1, 6, 1, 1]))
            gv: R.Tensor((1, 6, 55, 55), dtype="float32") = R.add(lv, lv1)
            R.output(gv)
        return gv
```

### Expected

With `output_padding = [1, 1]` the imported module matches the ONNX shape:

```
class Module:
    def main(input: R.Tensor((1, 3, 28, 28), dtype="float32"), weight: R.Tensor((3, 6, 3, 3), dtype="float32"), bias: R.Tensor((6,), dtype="float32")) -> R.Tensor((1, 6, 56, 56), dtype="float32"):
        R.func_attr({"num_input": 1, "params": [metadata["ffi.Tensor"][0], metadata["ffi.Tensor"][1]]})
        with R.dataflow():
            lv: R.Tensor((1, 6, 56, 56), dtype="float32") = R.nn.conv2d_transpose(input, weight, strides=[2, 2], padding=[1, 1, 1, 1], output_padding=[1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="IOHW", out_layout="NCHW", out_dtype="void")
            lv1: R.Tensor((1, 6, 1, 1), dtype="float32") = R.reshape(bias, R.shape([1, 6, 1, 1]))
            gv: R.Tensor((1, 6, 56, 56), dtype="float32") = R.add(lv, lv1)
            R.output(gv)
        return gv
```

### Resolve

- When the frontend converts an ONNX ConvTranspose node into an equivalent Relax expression, read the `output_padding` attribute and pass it to the op.
- Fixes: #18601
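As a supplement to the steps above, a minimal self-contained repro sketch (an assumed usage of `onnx.helper` and the Relax ONNX importer, not part of this commit; weight and bias are modeled as plain graph inputs for brevity):

```python
# Hypothetical repro sketch for the ConvTranspose shape mismatch described above.
import onnx
from onnx import TensorProto, helper

from tvm.relax.frontend.onnx import from_onnx

# ConvTranspose matching the Summary: 1x3x28x28 input, 3x6x3x3 weight,
# stride 2, padding 1, output_padding 1 -> ONNX expects a 1x6x56x56 output.
node = helper.make_node(
    "ConvTranspose",
    inputs=["input", "weight", "bias"],
    outputs=["output"],
    strides=[2, 2],
    pads=[1, 1, 1, 1],
    output_padding=[1, 1],
)
graph = helper.make_graph(
    [node],
    "conv_transpose_repro",
    inputs=[
        helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 28, 28]),
        helper.make_tensor_value_info("weight", TensorProto.FLOAT, [3, 6, 3, 3]),
        helper.make_tensor_value_info("bias", TensorProto.FLOAT, [6]),
    ],
    outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 6, 56, 56])],
)
model = helper.make_model(graph)
onnx.checker.check_model(model)

# Before this fix the imported R.nn.conv2d_transpose used output_padding=[0, 0]
# and reported a 1x6x55x55 result; with the fix it is 1x6x56x56.
mod = from_onnx(model)
mod.show()
```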
1 parent 7b9d3d9 commit e46c061

File tree

3 files changed: +6, -3 lines

python/tvm/relax/frontend/onnx/onnx_frontend.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -1334,6 +1334,7 @@ def _impl_v1(cls, bb, inputs, attr, params):
             weight=inputs[1],
             strides=attr.get("strides", 1),
             padding=attr.get("pads", 0),
+            output_padding=attr.get("output_padding", 0),
             dilation=attr.get("dilations", 1),
             groups=attr.get("group", 1),
             data_layout=data_layout,
```
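For illustration, a small self-contained sketch (assumed direct use of the Relax op and BlockBuilder APIs, not the frontend source) of how the ONNX attributes map onto `relax.op.nn.conv2d_transpose` keyword arguments, including the newly forwarded `output_padding`:

```python
from tvm import relax

# ONNX node attributes for the repro model above (assumed example values).
attr = {"strides": [2, 2], "pads": [1, 1, 1, 1], "output_padding": [1, 1]}

data = relax.Var("data", relax.TensorStructInfo((1, 3, 28, 28), "float32"))
weight = relax.Var("weight", relax.TensorStructInfo((3, 6, 3, 3), "float32"))

# Same attribute-to-argument mapping as in the diff; output_padding defaults to 0
# when the ONNX node omits it.
call = relax.op.nn.conv2d_transpose(
    data,
    weight,
    strides=attr.get("strides", 1),
    padding=attr.get("pads", 0),
    output_padding=attr.get("output_padding", 0),
    dilation=attr.get("dilations", 1),
    groups=attr.get("group", 1),
)

# Normalizing through a BlockBuilder runs shape inference; with output_padding=[1, 1]
# the result is expected to be R.Tensor((1, 6, 56, 56), "float32").
bb = relax.BlockBuilder()
with bb.function("main", [data, weight]):
    out = bb.emit(call)
    bb.emit_func_output(out)
print(bb.get())
```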

src/relax/op/nn/convolution.cc

Lines changed: 1 addition & 1 deletion
```diff
@@ -786,7 +786,7 @@ Expr conv2d_transpose(Expr data, Expr weight, ffi::Array<IntImm> strides,
   CHECK_GT(groups, 0) << "The number of groups in convolution is expected to be positive. However, "
                          "the given number of groups is "
                       << groups;
-  CHECK_EQ(output_padding.size(), 2) << "The input output_padding length is expected to be 4. "
+  CHECK_EQ(output_padding.size(), 2) << "The input output_padding length is expected to be 2. "
                                         "However, the given output_padding is "
                                      << output_padding;
   CHECK_EQ(strides.size(), 2)
```
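The length-2 requirement reflects one `output_padding` value per spatial dimension of a 2-D transposed convolution. The standard transposed-convolution size rule below (a plain-Python sketch, not TVM's implementation) reproduces the 55-vs-56 mismatch from the commit message:

```python
def conv_transpose_out_size(in_size, kernel, stride, pad_begin, pad_end, dilation, output_padding):
    """Standard transposed-convolution output size for one spatial dimension."""
    return (in_size - 1) * stride - (pad_begin + pad_end) + dilation * (kernel - 1) + output_padding + 1

# The 28x28 input from the commit message, kernel 3, stride 2, padding 1:
print(conv_transpose_out_size(28, 3, 2, 1, 1, 1, output_padding=0))  # 55, the shape TVM produced before the fix
print(conv_transpose_out_size(28, 3, 2, 1, 1, 1, output_padding=1))  # 56, the shape ONNX expects with output_padding=[1, 1]
```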

tests/python/relax/test_frontend_onnx.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -1171,11 +1171,12 @@ def _verify_conv(input_shape, weight_shape):
     _verify_conv([3, 4, 32, 32, 32], [2, 4, 3, 3, 3])  # group=2
 
 
-@pytest.mark.parametrize("stride", [1, 2])
+@pytest.mark.parametrize("stride", [2])
 @pytest.mark.parametrize("dilation", [1])
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("pad", [0, 2])
-def test_conv_transpose(stride: int, dilation: int, pad: int, bias: bool):
+@pytest.mark.parametrize("output_pad", [0, 1])
+def test_conv_transpose(stride: int, dilation: int, pad: int, bias: bool, output_pad: int):
     def _verify_conv_transpose(input_shape, weight_shape):
         nd = len(weight_shape) - 2
         output_shape = [input_shape[0], weight_shape[0]] + [
@@ -1190,6 +1191,7 @@ def _verify_conv_transpose(input_shape, weight_shape):
             strides=[stride] * nd,
             dilations=[dilation] * nd,
             pads=[pad] * nd * 2,
+            output_padding=[output_pad] * nd,
             group=input_shape[1] // weight_shape[1],
         )
         graph = helper.make_graph(
```
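With the new `output_pad` parametrization, the regression can be checked locally by running the updated test in isolation, e.g. `pytest tests/python/relax/test_frontend_onnx.py -k test_conv_transpose` (assuming a local TVM build with the ONNX frontend dependencies and onnxruntime available).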
