Skip to content

Commit da2bfbd

Browse files
authored
Merge pull request #307 from zhijxu-MS/tmp_branch_for_PR
fix bugs from ops
2 parents 9d0ee88 + bd7ae99 commit da2bfbd

File tree

1 file changed

+18
-17
lines changed

1 file changed

+18
-17
lines changed

tf2onnx/tfonnx.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -399,22 +399,17 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
399399
# kernel must be transposed
400400
if with_kernel:
401401
parent = node.inputs[1]
402-
if node.inputs[1].is_const():
403-
# kernel is const - transpose the const
404-
if not parent.data_format:
405-
val = parent.get_tensor_value(as_list=False)
406-
val = val.transpose(HWCN_TO_NCHW)
407-
parent.set_tensor_value(val)
408-
else:
409-
# kernel comes from op, insert transpose op
410-
input_name = node.input[1]
411-
transpose = ctx.insert_new_node_on_input(node, "Transpose", input_name)
412-
transpose.set_attr("perm", HWCN_TO_NCHW)
413-
transpose.inserted_nchw = True
414-
ctx.copy_shape(input_name, transpose.output[0])
415-
new_shape = spatial_map(ctx.get_shape(input_name), HWCN_TO_NCHW)
416-
ctx.set_shape(transpose.output[0], new_shape)
417-
nodes.append(transpose)
402+
# note: kernel may be used by multiple nodes,
403+
# so even if the kernel is a const, transposing the kernel can't be done statically.
404+
# so a "transpose" op is inserted here, and a later optimization phase will remove it if possible.
405+
input_name = node.input[1]
406+
transpose = ctx.insert_new_node_on_input(node, "Transpose", input_name)
407+
transpose.set_attr("perm", HWCN_TO_NCHW)
408+
transpose.inserted_nchw = True
409+
ctx.copy_shape(input_name, transpose.output[0])
410+
new_shape = spatial_map(ctx.get_shape(input_name), HWCN_TO_NCHW)
411+
ctx.set_shape(transpose.output[0], new_shape)
412+
nodes.append(transpose)
418413
parent.data_format = "NCHW"
419414

420415
# some onnx conv ops require the reshape the kernel (ie. depthwise_conv2d)
@@ -1227,6 +1222,12 @@ def minmax_op(ctx, node, name, args):
12271222

12281223

12291224
def pack_op(ctx, node, name, args):
1225+
# in tf, "pack" can accept a single input tensor, in which case it is a no-op,
1226+
# so remove the node in ONNX
1227+
if len(node.inputs) == 1:
1228+
ctx.replace_all_inputs(ctx.get_nodes(), node.output[0], node.input[0])
1229+
return None
1230+
12301231
# hack to make up for the missing onnx pack op
12311232
axis = node.get_attr("axis").i
12321233
if axis < 0:
@@ -1436,7 +1437,7 @@ def fill_op7(ctx, node, name, args):
14361437

14371438

14381439
def fill_op(ctx, node, name, args):
1439-
node.type = "ConstantLike"
1440+
node.type = "ConstantOfShape"
14401441
# both shape and value in tensorflow are passed as tensor.
14411442
# In onnx the value is an attribute so we need to fetch the value as const which
14421443
# sooner or later will be a problem for tensorflow-onnx.

0 commit comments

Comments
 (0)