Skip to content

Commit bd7ae99

Browse files
committed
conv kernel may be used by multiple conv nodes, so we can't transpose it directly
1 parent 6e9885c commit bd7ae99

File tree

1 file changed

+11
-16
lines changed

1 file changed

+11
-16
lines changed

tf2onnx/tfonnx.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -422,22 +422,17 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
422422
# kernel must to be transposed
423423
if with_kernel:
424424
parent = node.inputs[1]
425-
if node.inputs[1].is_const():
426-
# kernel is const - transpose the const
427-
if not parent.data_format:
428-
val = parent.get_tensor_value(as_list=False)
429-
val = val.transpose(HWCN_TO_NCHW)
430-
parent.set_tensor_value(val)
431-
else:
432-
# kernel comes from op, insert transpose op
433-
input_name = node.input[1]
434-
transpose = ctx.insert_new_node_on_input(node, "Transpose", input_name)
435-
transpose.set_attr("perm", HWCN_TO_NCHW)
436-
transpose.inserted_nchw = True
437-
ctx.copy_shape(input_name, transpose.output[0])
438-
new_shape = spatial_map(ctx.get_shape(input_name), HWCN_TO_NCHW)
439-
ctx.set_shape(transpose.output[0], new_shape)
440-
nodes.append(transpose)
425+
# note: kernel may be used by multiple nodes,
426+
# so even kernel is a const, transposing kernel can't be done statically.
427+
# so "transpose" op is inserted here and will consider to remove it in later optimization phase if possible.
428+
input_name = node.input[1]
429+
transpose = ctx.insert_new_node_on_input(node, "Transpose", input_name)
430+
transpose.set_attr("perm", HWCN_TO_NCHW)
431+
transpose.inserted_nchw = True
432+
ctx.copy_shape(input_name, transpose.output[0])
433+
new_shape = spatial_map(ctx.get_shape(input_name), HWCN_TO_NCHW)
434+
ctx.set_shape(transpose.output[0], new_shape)
435+
nodes.append(transpose)
441436
parent.data_format = "NCHW"
442437

443438
# some onnx conv ops require the reshape the kernel (ie. depthwise_conv2d)

0 commit comments

Comments
 (0)