Skip to content

Commit 64596bc

Browse files
committed
fix conv_convert_inputs
1 parent 9183796 commit 64596bc

File tree

1 file changed

+5
-12
lines changed

1 file changed

+5
-12
lines changed

tf2onnx/onnx_opset/nn.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,10 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
5555
# transpose input if needed, no need to record shapes on input
5656
for idx in input_indices:
5757
parent = node.inputs[idx]
58-
if node.inputs[idx].is_const():
59-
# if input is a constant, transpose that one
60-
if not parent.data_format:
61-
val = parent.get_tensor_value(as_list=False)
62-
parent.set_tensor_value(val.transpose(constants.NHWC_TO_NCHW))
58+
if node.inputs[idx].is_const() and len(ctx.find_output_consumers(node.input[1])) == 1:
59+
# if input is a constant, transpose that one if we are the only consumer
60+
val = parent.get_tensor_value(as_list=False)
61+
parent.set_tensor_value(val.transpose(constants.NHWC_TO_NCHW))
6362
else:
6463
# if input comes from a op, insert transpose op
6564
input_name = node.input[idx]
@@ -70,33 +69,27 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
7069
if shape is not None:
7170
new_shape = spatial_map(shape, constants.NHWC_TO_NCHW)
7271
ctx.set_shape(transpose.output[0], new_shape)
73-
parent.data_format = "NCHW"
7472

7573
# kernel must to be transposed
7674
if with_kernel:
7775
parent = node.inputs[1]
7876
need_transpose = True
7977
if node.inputs[1].is_const():
8078
# kernel is const - transpose the const if we are the only consumer of const
81-
# TODO: maybe we should make a copy of the const, or look at the other consumers
82-
# if they'd want a transose as well.
8379
consumers = ctx.find_output_consumers(node.input[1])
8480
if len(consumers) == 1:
8581
val = parent.get_tensor_value(as_list=False)
8682
val = val.transpose(constants.HWCN_TO_NCHW)
8783
parent.set_tensor_value(val)
88-
parent.data_format = "NCHW"
8984
need_transpose = False
9085

9186
if need_transpose:
9287
input_name = node.input[1]
9388
transpose = ctx.insert_new_node_on_input(node, "Transpose", input_name)
9489
transpose.set_attr("perm", constants.HWCN_TO_NCHW)
9590
transpose.skip_conversion = True
96-
ctx.copy_shape(input_name, transpose.output[0])
9791
new_shape = spatial_map(ctx.get_shape(input_name), constants.HWCN_TO_NCHW)
9892
ctx.set_shape(transpose.output[0], new_shape)
99-
parent.data_format = "NCHW"
10093

10194
# some onnx conv ops require the reshape the kernel (ie. depthwise_conv2d)
10295
if new_kernel_shape:
@@ -129,7 +122,7 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
129122
ctx.set_shape(transpose.output[0], output_shape)
130123
# Transpose TF NHWC shape back to NCHW shape for current ONNX conv node output
131124
ctx.set_shape(output_name, spatial_map(output_shape, constants.NHWC_TO_NCHW))
132-
node.data_format = "NCHW"
125+
node.data_format = "NCHW"
133126

134127

135128
def add_padding(ctx, node, kernel_shape, strides, dilations=None, spatial=2):

0 commit comments

Comments
 (0)