Skip to content

Commit 6566571

Browse files
authored
Merge pull request #82 from pengwa/fix-convtranspose
fix Conv2DBackpropInput conversion failure
2 parents 16480fb + abcfc82 commit 6566571

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

tf2onnx/tfonnx.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,7 @@ def conv_dims_attr(node, name, new_name=None):
463463

464464

465465
def conv_kernel_shape(ctx, node, input_idx, spatial=2):
466-
kernel_shape = ctx.get_shape(node.input[1])
466+
kernel_shape = ctx.get_shape(node.input[input_idx])
467467
if len(kernel_shape) != 2 * spatial:
468468
raise ValueError("kernel rank must be 2* spatial")
469469
kernel_shape = kernel_shape[0:spatial]
@@ -492,6 +492,8 @@ def convtranspose_op(ctx, node, name, args):
492492

493493
# Note: inputs are reversed from what one would expect.
494494
kernel_shape = conv_kernel_shape(ctx, node, 1)
495+
496+
# ouput_shape is explicitly specified here, in this case pads values are auto generated/calculated.
495497
output_shape = node.inputs[0].get_tensor_value()
496498
if node.is_nhwc():
497499
new_output_shape = [output_shape[1], output_shape[2]]
@@ -501,15 +503,17 @@ def convtranspose_op(ctx, node, name, args):
501503

502504
strides = conv_dims_attr(node, "strides")
503505
conv_dims_attr(node, "dilations")
504-
add_padding(ctx, node, kernel_shape, strides)
505506

506-
# remove output_shapes input, swap data and kernel
507+
# remove output_shapes input
507508
ctx.remove_input(node, node.input[0])
509+
# swap data and kernel
508510
t = node.input[0]
509511
node.input[0] = node.input[1]
510512
node.input[1] = t
511513

512514
nodes = conv_convert_inputs(ctx, node, with_kernel=True)
515+
516+
# Note: output_padding, group are left default.
513517
return nodes
514518

515519

0 commit comments

Comments
 (0)