Skip to content

Commit daeab0d

Browse files
Fix bug in ConvTranspose (#1542)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 2f9078e commit daeab0d

File tree

1 file changed

+4
-7
lines changed

1 file changed

+4
-7
lines changed

tf2onnx/onnx_opset/nn.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -263,8 +263,9 @@ def parse_dims_attr(node, dims, spatial):
263263
if len(dims) != spatial:
264264
dims = dims[1:-1]
265265
else:
266-
# We have (N, C, ...).
267-
dims = dims[2:]
266+
# We have (N, C, ...) or (...).
267+
if len(dims) != spatial:
268+
dims = dims[2:]
268269
return dims
269270

270271
def conv_dims_attr(node, name, new_name=None, spatial=2):
@@ -459,15 +460,11 @@ def version_1(cls, ctx, node, **kwargs):
459460
dilations = parse_dims_attr(node, node.get_attr("dilations").ints, spatial)
460461
else:
461462
dilations = [1] * spatial
462-
if "output_padding" in node.attr:
463-
output_padding = parse_dims_attr(node, node.get_attr("output_padding").ints, spatial)
464-
else:
465-
output_padding = [0] * spatial
466463
kernel_shape = parse_dims_attr(node, node.get_attr("kernel_shape").ints, spatial)
467464
total_padding = [-1] * spatial
468465
pads = [1] * (spatial * 2)
469466
for i in range(spatial):
470-
total_padding[i] = (strides[i] * (input_dims[i] - 1) + output_padding[i]
467+
total_padding[i] = (strides[i] * (input_dims[i] - 1)
471468
+ ((kernel_shape[i] - 1) * dilations[i] + 1)
472469
- new_output_shape[i])
473470
start_i = i

0 commit comments

Comments
 (0)