Skip to content

Commit 271d260

Browse files
authored
Merge pull request #341 from onnx/gs/loader
fix const transpose
2 parents 5433ef6 + 6bf4715 commit 271d260

File tree

3 files changed

+27
-21
lines changed

3 files changed

+27
-21
lines changed

tf2onnx/graph.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ def __init__(self, node, graph, skip_conversion=False):
3939
self._input = [i for i in node.input]
4040
self._output = [i for i in node.output]
4141
self._attr = {}
42-
self.inserted_nchw = False
4342

4443
graph.set_node_by_name(self)
4544
# dict to original attributes
@@ -1064,7 +1063,6 @@ def create_graph_from_onnx_graph(graph_proto):
10641063
all_nodes.extend(const_nodes_from_initializer)
10651064
g.set_nodes(all_nodes)
10661065

1067-
10681066
GraphUtil._parse_graph_input(g, graph_proto)
10691067

10701068
for n in g.get_nodes():

tf2onnx/loader.py

Whitespace-only changes.

tf2onnx/tfonnx.py

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,6 @@ def tensorflow_to_onnx(graph, shape_override):
146146

147147
def _convert_shapenode_to_int64(ctx, node, input_number):
148148
"""cast int32 shape into int64 shape."""
149-
shape_node = node.inputs[input_number]
150149
name = node.input[input_number]
151150

152151
cast_node = ctx.insert_new_node_on_input(node, "Cast", name)
@@ -382,7 +381,6 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
382381
input_name = node.input[idx]
383382
transpose = ctx.insert_new_node_on_input(node, "Transpose", input_name)
384383
transpose.set_attr("perm", NHWC_TO_NCHW)
385-
transpose.inserted_nchw = True
386384
transpose.skip_conversion = True
387385
shape = ctx.get_shape(input_name)
388386
new_shape = spatial_map(shape, NHWC_TO_NCHW)
@@ -393,19 +391,30 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
393391
# kernel must to be transposed
394392
if with_kernel:
395393
parent = node.inputs[1]
396-
# note: kernel may be used by multiple nodes,
397-
# so even kernel is a const, transposing kernel can't be done statically.
398-
# so "transpose" op is inserted here and will consider to remove it in later optimization phase if possible.
399-
input_name = node.input[1]
400-
transpose = ctx.insert_new_node_on_input(node, "Transpose", input_name)
401-
transpose.set_attr("perm", HWCN_TO_NCHW)
402-
transpose.inserted_nchw = True
403-
transpose.skip_conversion = True
404-
ctx.copy_shape(input_name, transpose.output[0])
405-
new_shape = spatial_map(ctx.get_shape(input_name), HWCN_TO_NCHW)
406-
ctx.set_shape(transpose.output[0], new_shape)
407-
nodes.append(transpose)
408-
parent.data_format = "NCHW"
394+
395+
need_transpose = True
396+
if node.inputs[1].is_const():
397+
# kernel is const - transpose the const if we are the only consumer of const
398+
# TODO: maybe we should make a copy of the const, or look at the other consumers
399+
# if they'd want a transpose as well.
400+
consumers = ctx.find_output_consumers(node.input[1])
401+
if len(consumers) == 1:
402+
val = parent.get_tensor_value(as_list=False)
403+
val = val.transpose(HWCN_TO_NCHW)
404+
parent.set_tensor_value(val)
405+
parent.data_format = "NCHW"
406+
need_transpose = False
407+
408+
if need_transpose:
409+
input_name = node.input[1]
410+
transpose = ctx.insert_new_node_on_input(node, "Transpose", input_name)
411+
transpose.set_attr("perm", HWCN_TO_NCHW)
412+
transpose.skip_conversion = True
413+
ctx.copy_shape(input_name, transpose.output[0])
414+
new_shape = spatial_map(ctx.get_shape(input_name), HWCN_TO_NCHW)
415+
ctx.set_shape(transpose.output[0], new_shape)
416+
nodes.append(transpose)
417+
parent.data_format = "NCHW"
409418

410419
# some onnx conv ops require the reshape the kernel (ie. depthwise_conv2d)
411420
if new_kernel_shape:
@@ -436,7 +445,6 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
436445
op_name = utils.make_name(node.name)
437446
transpose = ctx.insert_new_node_on_output("Transpose", output_name, name=op_name)
438447
transpose.set_attr("perm", NCHW_TO_NHWC)
439-
transpose.inserted_nchw = True
440448
transpose.skip_conversion = True
441449
ctx.set_shape(transpose.output[0], ctx.get_shape(node.output[idx]))
442450
nodes.append(transpose)
@@ -2434,7 +2442,6 @@ def transpose_inputs(ctx, inputs_as_nchw):
24342442
op_name = utils.make_name(node.name)
24352443
transpose = ctx.insert_new_node_on_output("Transpose", output_name, name=op_name)
24362444
transpose.set_attr("perm", NCHW_TO_NHWC)
2437-
transpose.inserted_nchw = True
24382445
ctx.copy_shape(output_name, transpose.output[0])
24392446
ctx.set_shape(output_name, np.array(shape)[NHWC_TO_NCHW])
24402447
ops.append(transpose)
@@ -2527,8 +2534,9 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
25272534
# check output existence in case user passed in wrong output ids
25282535
non_exists = set(io_to_check) - set(output_shapes.keys())
25292536
if non_exists:
2530-
log.error("\nFailed to convert: inputs/outputs specified do not exist, make sure your passed" \
2531-
" in format: input/output_node_name:port_id. Problematical inputs/outputs are: %s \n", non_exists)
2537+
log.error("\nFailed to convert: inputs/outputs specified do not exist, make sure your passed"
2538+
"in format: input/output_node_name:port_id. Problematical inputs/outputs are: %s \n",
2539+
non_exists)
25322540
raise ValueError("Inputs/Outputs Not Found")
25332541

25342542
g = Graph(onnx_nodes, output_shapes, dtypes, target, opset, extra_opset, output_names)

0 commit comments

Comments
 (0)