Skip to content

Commit 20dae5d

Browse files
authored
Merge pull request #369 from nbcsm/conv
fix conv output shape
2 parents b15f792 + 58ca589 commit 20dae5d

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

tf2onnx/tfonnx.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -430,11 +430,15 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
430430
if node.is_nhwc():
431431
for idx in output_indices:
432432
output_name = node.output[idx]
433+
output_shape = ctx.get_shape(node.output[idx])
433434
op_name = utils.make_name(node.name)
434435
transpose = ctx.insert_new_node_on_output("Transpose", output_name, name=op_name)
435436
transpose.set_attr("perm", NCHW_TO_NHWC)
436437
transpose.skip_conversion = True
437-
ctx.set_shape(transpose.output[0], ctx.get_shape(node.output[idx]))
438+
# set TF NHWC shape to transpose node output
439+
ctx.set_shape(transpose.output[0], output_shape)
440+
# Transpose TF NHWC shape back to NCHW shape for current ONNX conv node output
441+
ctx.set_shape(output_name, spatial_map(output_shape, NHWC_TO_NCHW))
438442
node.data_format = "NCHW"
439443

440444

0 commit comments

Comments
 (0)