Commit e46c061
authored
[Relax][Onnx] Pass output_padding param in ConvTranspose (#18635)
### Summary
Inconsistent shapes of results produced by TVM and ONNX View due to the
ConvTranspose operator
### Steps to Reproduce
- ONNX View: Output Shape = (1, 6, 56, 56)
<img width="500" height="350" alt="Screenshots"
src="https://github.com/user-attachments/assets/8a7129b4-9ebc-47a0-aa47-6060e8df2827"
/>
- TVM: Output Shape = (1, 6, 55, 55) (due to output_padding=[0, 0])
```
class Module:
def main(input: R.Tensor((1, 3, 28, 28), dtype="float32"), weight: R.Tensor((3, 6, 3, 3), dtype="float32"), bias: R.Tensor((6,), dtype="float32")) -> R.Tensor((1, 6, 55, 55), dtype="float32"):
R.func_attr({"num_input": 1, "params": [metadata["ffi.Tensor"][0], metadata["ffi.Tensor"][1]]})
with R.dataflow():
lv: R.Tensor((1, 6, 55, 55), dtype="float32") = R.nn.conv2d_transpose(input, weight, strides=[2, 2], padding=[1, 1, 1, 1], output_padding=[0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="IOHW", out_layout="NCHW", out_dtype="void")
lv1: R.Tensor((1, 6, 1, 1), dtype="float32") = R.reshape(bias, R.shape([1, 6, 1, 1]))
gv: R.Tensor((1, 6, 55, 55), dtype="float32") = R.add(lv, lv1)
R.output(gv)
return gv
```
### Expected
- output_padding = [1, 1]
```
class Module:
def main(input: R.Tensor((1, 3, 28, 28), dtype="float32"), weight: R.Tensor((3, 6, 3, 3), dtype="float32"), bias: R.Tensor((6,), dtype="float32")) -> R.Tensor((1, 6, 56, 56), dtype="float32"):
R.func_attr({"num_input": 1, "params": [metadata["ffi.Tensor"][0], metadata["ffi.Tensor"][1]]})
with R.dataflow():
lv: R.Tensor((1, 6, 56, 56), dtype="float32") = R.nn.conv2d_transpose(input, weight, strides=[2, 2], padding=[1, 1, 1, 1], output_padding=[1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="IOHW", out_layout="NCHW", out_dtype="void")
lv1: R.Tensor((1, 6, 1, 1), dtype="float32") = R.reshape(bias, R.shape([1, 6, 1, 1]))
gv: R.Tensor((1, 6, 56, 56), dtype="float32") = R.add(lv, lv1)
R.output(gv)
return gv
```
### Resolve
- When implement converts an onnx ConvTranspose node into an equivalent
Relax expression, get and pass output_padding param into op.
- Fixed: #186011 parent 7b9d3d9 commit e46c061
File tree
3 files changed
+6
-3
lines changed- python/tvm/relax/frontend/onnx
- src/relax/op/nn
- tests/python/relax
3 files changed
+6
-3
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1334 | 1334 | | |
1335 | 1335 | | |
1336 | 1336 | | |
| 1337 | + | |
1337 | 1338 | | |
1338 | 1339 | | |
1339 | 1340 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
786 | 786 | | |
787 | 787 | | |
788 | 788 | | |
789 | | - | |
| 789 | + | |
790 | 790 | | |
791 | 791 | | |
792 | 792 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1171 | 1171 | | |
1172 | 1172 | | |
1173 | 1173 | | |
1174 | | - | |
| 1174 | + | |
1175 | 1175 | | |
1176 | 1176 | | |
1177 | 1177 | | |
1178 | | - | |
| 1178 | + | |
| 1179 | + | |
1179 | 1180 | | |
1180 | 1181 | | |
1181 | 1182 | | |
| |||
1190 | 1191 | | |
1191 | 1192 | | |
1192 | 1193 | | |
| 1194 | + | |
1193 | 1195 | | |
1194 | 1196 | | |
1195 | 1197 | | |
| |||
0 commit comments