Commit 7dfb22d

Edge case: The output of a layer is both the input to another layer and also the output of the model.
1 parent c34ac1d commit 7dfb22d
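
The pattern this commit handles, sketched as a minimal Keras model (the layer names and the converter invocation are illustrative; --outputs-as-nchw is tf2onnx's documented flag for requesting NCHW model outputs):

    import tensorflow as tf

    inp = tf.keras.Input(shape=(32, 32, 3), name="input")  # NHWC
    feat = tf.keras.layers.Conv2D(8, 3, padding="same", name="feat")(inp)
    head = tf.keras.layers.Conv2D(4, 1, name="head")(feat)
    # "feat" feeds the "head" layer AND is exposed as a model output.
    model = tf.keras.Model(inp, [feat, head])
    tf.saved_model.save(model, "saved_model")

    # python -m tf2onnx.convert --saved-model saved_model --output model.onnx \
    #     --outputs-as-nchw <name of the feat output tensor>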

File tree

tf2onnx/graph.py
tf2onnx/optimizer/__init__.py
tf2onnx/tfonnx.py
tf2onnx/version.py

4 files changed: 93 additions and 20 deletions

tf2onnx/graph.py

Lines changed: 2 additions & 2 deletions
@@ -486,6 +486,7 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
         self.contained_graphs = {}  # {node_name: {node_attribute_name: Graph}}

         ops = [Node(node, self) for node in nodes]
+
         if input_names is not None:
             input_names_set = set(input_names)
             for n in ops:
@@ -740,7 +741,6 @@ def reset_nodes(self, ops):

             if op.name in self.contained_graphs:
                 remained_sub_graphs[op.name] = self.contained_graphs[op.name]
-
         self._nodes = ops
         self.contained_graphs = remained_sub_graphs
         self._nodes_by_name = {op.name: op for op in ops}
@@ -758,7 +758,7 @@ def reset_nodes(self, ops):
                 raise ValueError("graph input '" + n.name + "' not exist")
         for o in self.outputs:
             if o not in self._output_to_node_name:
-                raise ValueError("graph output '" + o.name + "' not exist")
+                raise ValueError("graph output '" + str(o) + "' not exist")

         self._dtypes = remained_dtypes
         self._output_shapes = remained_shapes
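
A side note on the reset_nodes() fix above: Graph.outputs holds output tensor names as plain strings (hence the membership test against _output_to_node_name), so the old error path raised AttributeError while formatting the message and masked the intended ValueError. A minimal reproduction in plain Python, outside tf2onnx:

    o = "conv:0"  # graph outputs are tensor-name strings, not Node objects
    try:
        raise ValueError("graph output '" + o.name + "' not exist")  # old code
    except AttributeError as ex:
        print(ex)  # 'str' object has no attribute 'name'
    print("graph output '" + str(o) + "' not exist")  # fixed message formats fine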

tf2onnx/optimizer/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -21,7 +21,8 @@

 # optimizer sequence need to be considered carefully
 _optimizers = OrderedDict([
-    ("optimize_transpose", TransposeOptimizer),
+    ("remove_identity", IdentityOptimizer),
+    ("optimize_transpose", TransposeOptimizer),  # transpose to reshape or add reshape
     ("remove_redundant_upsample", UpsampleOptimizer),
     ("fold_constants", ConstFoldOptimizer),
     ("const_dequantize_optimizer", ConstDequantizeOptimizer),
@@ -32,7 +33,6 @@
     ("reshape_optimizer", ReshapeOptimizer),
     ("global_pool_optimizer", GlobalPoolOptimizer),
     ("q_dq_optimizer", QDQOptimizer),
-    ("remove_identity", IdentityOptimizer),
     ("remove_back_to_back", BackToBackOptimizer),
     ("einsum_optimizer", EinsumOptimizer),
 ])
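
Why the reordering likely matters: the edge-case handling in tf2onnx/tfonnx.py below inserts an Identity node in front of the output Transpose, and an Identity sitting between two transposes hides the pair from transpose-merging passes, so identities are now removed first. A toy illustration of the ordering effect, with made-up pass names rather than tf2onnx's real optimizer classes:

    def remove_identity(nodes):
        # Drop Identity ops so their producer and consumer become adjacent.
        return [n for n in nodes if n != "Identity"]

    def merge_inverse_transposes(nodes):
        # Collapse adjacent Transpose pairs (assumed to be mutual inverses).
        out = []
        for n in nodes:
            if out and out[-1] == "Transpose" and n == "Transpose":
                out.pop()
            else:
                out.append(n)
        return out

    graph = ["Conv", "Transpose", "Identity", "Transpose", "Relu"]
    print(merge_inverse_transposes(remove_identity(graph)))  # ['Conv', 'Relu']
    print(remove_identity(merge_inverse_transposes(graph)))  # ['Conv', 'Transpose', 'Transpose', 'Relu']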

tf2onnx/tfonnx.py

Lines changed: 88 additions & 15 deletions
@@ -246,7 +246,7 @@ def tensorflow_onnx_mapping(g, ops_mapping, initialized_tables=None, is_tflite=F
     exceptions = []
     if initialized_tables is None:
         initialized_tables = {}
-
+
     ops = list(g.get_nodes())
     for node in ops:
         logger.debug("Process node: %s\n%s", node.name, node.summary)
@@ -263,7 +263,7 @@ def tensorflow_onnx_mapping(g, ops_mapping, initialized_tables=None, is_tflite=F
                 logger.error("Tensorflow op [%s: %s] is not supported", node.name, op)
             continue
         mapped_op[op] += 1
-
+
         func, kwargs = map_info
         if kwargs:
             # if there is a tf_op/onnx_op key we'll map the old type to a new type
@@ -273,6 +273,7 @@ def tensorflow_onnx_mapping(g, ops_mapping, initialized_tables=None, is_tflite=F
                 kwargs["tfl_op" if is_tflite else "tf_op"] = op
                 node.type = converted_op
         body_graphs = node.get_body_graphs()
+
         if body_graphs:
             for attr, b_g in body_graphs.items():
                 logger.debug("start handling subgraph of %s's attribute %s", node.name, attr)
@@ -287,7 +288,7 @@ def tensorflow_onnx_mapping(g, ops_mapping, initialized_tables=None, is_tflite=F
                 b_g.topological_sort(b_g.get_nodes())
                 exceptions.extend(body_exceptions)
                 logger.debug("finish handling subgraph of %s's attribute %s", node.name, attr)
-
+
         try:
             func(g, node, **kwargs, initialized_tables=initialized_tables, dequantize=dequantize)
             if not is_tflite:
@@ -302,7 +303,6 @@ def tensorflow_onnx_mapping(g, ops_mapping, initialized_tables=None, is_tflite=F
             logger.error("Failed to convert node %r (fct=%r)\n%r",
                          node.name, func, summary, exc_info=1)
             exceptions.append(ex)
-
     return mapped_op, unmapped_op, exceptions


@@ -332,26 +332,96 @@ def transpose_inputs(ctx, inputs_as_nchw):
 def transpose_outputs(ctx, outputs_as_nchw):
     """Insert a transpose from NHWC to NCHW on model output on users request."""
     ops = []
+
+    # First pass: Find and handle edge cases in original nodes
+    edge_case_handled = set()
+
     for node in ctx.get_nodes():
         for output_name in node.output:
-            if output_name in outputs_as_nchw:
+            # Check if this output is used to create a model output
+            consumers = ctx.find_output_consumers(output_name)
+
+            # Look for edge case: output consumed by both model output node and other nodes
+            model_output_consumers = []
+            other_consumers = []
+
+            for consumer in consumers:
+                if consumer.output and any(out in outputs_as_nchw for out in consumer.output):
+                    model_output_consumers.append(consumer)
+                else:
+                    other_consumers.append(consumer)
+
+            # Edge case: original node output goes to both model output and other layers
+            if model_output_consumers and other_consumers:
+                # Get shape for validation
+                shape = ctx.get_shape(output_name)
+                if len(shape) != len(constants.NHWC_TO_NCHW):
+                    continue
+
+                # Handle edge case: Use insert_node_on_output for proper structure
+                # Step 1: Create Identity node and insert it on the original output
+                identity_name = utils.make_name(node.name + "_identity")
+                identity = ctx.make_node("Identity", [output_name],
+                                         outputs=[identity_name + ":0"], name=identity_name)
+
+                # Copy shape information
+                ctx.copy_shape(output_name, identity.output[0])
+                ctx.set_shape(identity.output[0], shape)
+
+                # Insert the identity on the original output - this will redirect ALL consumers
+                ctx.insert_node_on_output(identity, output_name)
+
+                # Step 2: Create Transpose node and connect it to Identity
+                transpose_name = utils.make_name(identity.name + "_transpose")
+                transpose = ctx.make_node("Transpose", [identity.output[0]],
                                          outputs=[transpose_name + ":0"], name=transpose_name)
+                transpose.set_attr("perm", constants.NHWC_TO_NCHW)
+                ctx.copy_shape(identity.output[0], transpose.output[0])
+                ctx.set_shape(transpose.output[0], np.array(shape)[constants.NHWC_TO_NCHW])
+
+                # Step 3: Manually redirect ONLY the model output consumers to use transpose
+                for consumer in model_output_consumers:
+                    ctx.replace_all_inputs(identity.output[0], transpose.output[0], ops=[consumer])
+
+                # Mark this output as handled
+                edge_case_handled.add(output_name)
+
+                ops.append(node)
+                ops.append(identity)
+                ops.append(transpose)
+                break  # Only handle one edge case per node
+
+        # If no edge case was handled for this node, add it normally
+        if not any(out in edge_case_handled for out in node.output):
+            ops.append(node)
+
+    # Second pass: Handle normal cases (nodes that directly output to model outputs)
+    final_ops = []
+    for node in ops:
+        handled = False
+        for output_name in node.output:
+            if output_name in outputs_as_nchw and output_name not in edge_case_handled:
+                # Get shape for validation
                 shape = ctx.get_shape(output_name)
                 if len(shape) != len(constants.NHWC_TO_NCHW):
                     logger.warning("transpose_output for %s: shape must be rank 4, ignored" % output_name)
-                    ops.append(node)
                     continue
+
                 # insert transpose
                 op_name = utils.make_name(node.name)
-                transpose = ctx.insert_new_node_on_output("Transpose", node.input[0], name=op_name)
+                transpose = ctx.insert_new_node_on_output("Transpose", output_name, name=op_name)
                 transpose.set_attr("perm", constants.NHWC_TO_NCHW)
-                ctx.copy_shape(node.output[0], transpose.output[0])
-                ctx.set_shape(transpose.output[0], np.array(shape)[constants.NHWC_TO_NCHW])
+                ctx.copy_shape(output_name, transpose.output[0])
                 ctx.set_shape(output_name, np.array(shape)[constants.NHWC_TO_NCHW])
-                ops.append(transpose)
-                ops.append(node)
-                continue
-        ops.append(node)
-    ctx.reset_nodes(ops)
+                final_ops.append(transpose)
+                final_ops.append(node)
+                handled = True
+                break
+
+        if not handled:
+            final_ops.append(node)
+
+    ctx.reset_nodes(final_ops)

 def topological_sort(g, continue_on_error):
     ops = g.get_nodes()
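
The rewiring performed by the first pass above, reduced to a toy graph of plain dicts (hypothetical tensor names, not tf2onnx's Graph API): inserting a node "on an output" redirects every consumer at once, which is why the Identity hop is created first and only the model-output edge is then moved onto the Transpose. Note the second pass also fixes a pre-existing bug: the Transpose used to be inserted on node.input[0] instead of output_name.

    # Each node maps to the list of tensors it consumes; "__model_output__"
    # stands in for the graph's output binding.
    graph = {
        "relu": ["conv:0"],              # downstream layer reading conv's output
        "__model_output__": ["conv:0"],  # the same tensor is also a model output
    }

    def insert_on_output(g, new_tensor, old_tensor):
        # Redirect ALL consumers of old_tensor to new_tensor
        # (mirroring the behavior of ctx.insert_node_on_output).
        for node, inputs in g.items():
            g[node] = [new_tensor if t == old_tensor else t for t in inputs]
        g[new_tensor.split(":")[0]] = [old_tensor]

    insert_on_output(graph, "identity:0", "conv:0")  # step 1: all consumers -> Identity
    graph["transpose"] = ["identity:0"]              # step 2: NHWC->NCHW branch off Identity
    graph["__model_output__"] = ["transpose:0"]      # step 3: move ONLY the model output
    print(graph)
    # {'relu': ['identity:0'], '__model_output__': ['transpose:0'],
    #  'identity': ['conv:0'], 'transpose': ['identity:0']}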
@@ -522,7 +592,7 @@ def process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, outputs_as_nchw,
                          initialized_tables, is_tflite=False, dequantize=False):

     op_cnt, attr_cnt = g.dump_node_statistics(include_attrs=True, include_subgraphs=False)
-
+
     if is_tflite:
         tfl_rewriters = []
         if dequantize:
@@ -531,13 +601,16 @@ def process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, outputs_as_nchw,
         tfl_rewriters.append(rewrite_tfl_select_zero)
         tfl_rewriters.append(rewrite_tfl_rfft)
         run_rewriters(g, tfl_rewriters, continue_on_error)
+
         tfl_ops_mapping = handler.tfl_op.create_tfl_to_tf_mapping()
         _, _, exceptions = tensorflow_onnx_mapping(g, tfl_ops_mapping, is_tflite=True, dequantize=False)
+
         if exceptions and not continue_on_error:
             raise exceptions[0]

     # create ops mapping for the desired opsets
     ops_mapping = handler.tf_op.create_mapping(g.opset, g.extra_opset)
+

     # apply custom ops on top of the assembled opset. We can either complement the opset
     # or override existing ops with a custom op.

tf2onnx/version.py

Lines changed: 1 addition & 1 deletion
@@ -2,4 +2,4 @@


 version = '1.16.1'
-git_version = '13bab8a91e17ccd87541b2f361ab60e8e38359d3'
+git_version = 'None'
