Skip to content

Fix Edge Case: Shared Layer Output Between Model Output and Internal Layers #2407

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tf2onnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,7 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
self.contained_graphs = {} # {node_name: {node_attribute_name: Graph}}

ops = [Node(node, self) for node in nodes]

if input_names is not None:
input_names_set = set(input_names)
for n in ops:
Expand Down Expand Up @@ -740,7 +741,6 @@ def reset_nodes(self, ops):

if op.name in self.contained_graphs:
remained_sub_graphs[op.name] = self.contained_graphs[op.name]

self._nodes = ops
self.contained_graphs = remained_sub_graphs
self._nodes_by_name = {op.name: op for op in ops}
Expand All @@ -758,7 +758,7 @@ def reset_nodes(self, ops):
raise ValueError("graph input '" + n.name + "' not exist")
for o in self.outputs:
if o not in self._output_to_node_name:
raise ValueError("graph output '" + o.name + "' not exist")
raise ValueError("graph output '" + str(o) + "' not exist")

self._dtypes = remained_dtypes
self._output_shapes = remained_shapes
Expand Down
4 changes: 2 additions & 2 deletions tf2onnx/optimizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@

# optimizer sequence needs to be considered carefully
_optimizers = OrderedDict([
("optimize_transpose", TransposeOptimizer),
("remove_identity", IdentityOptimizer),
("optimize_transpose", TransposeOptimizer), # transpose to reshape or add reshape
("remove_redundant_upsample", UpsampleOptimizer),
("fold_constants", ConstFoldOptimizer),
("const_dequantize_optimizer", ConstDequantizeOptimizer),
Expand All @@ -32,7 +33,6 @@
("reshape_optimizer", ReshapeOptimizer),
("global_pool_optimizer", GlobalPoolOptimizer),
("q_dq_optimizer", QDQOptimizer),
("remove_identity", IdentityOptimizer),
("remove_back_to_back", BackToBackOptimizer),
("einsum_optimizer", EinsumOptimizer),
])
Expand Down
103 changes: 88 additions & 15 deletions tf2onnx/tfonnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def tensorflow_onnx_mapping(g, ops_mapping, initialized_tables=None, is_tflite=F
exceptions = []
if initialized_tables is None:
initialized_tables = {}

ops = list(g.get_nodes())
for node in ops:
logger.debug("Process node: %s\n%s", node.name, node.summary)
Expand All @@ -263,7 +263,7 @@ def tensorflow_onnx_mapping(g, ops_mapping, initialized_tables=None, is_tflite=F
logger.error("Tensorflow op [%s: %s] is not supported", node.name, op)
continue
mapped_op[op] += 1

func, kwargs = map_info
if kwargs:
# if there is a tf_op/onnx_op key we'll map the old type to a new type
Expand All @@ -273,6 +273,7 @@ def tensorflow_onnx_mapping(g, ops_mapping, initialized_tables=None, is_tflite=F
kwargs["tfl_op" if is_tflite else "tf_op"] = op
node.type = converted_op
body_graphs = node.get_body_graphs()

if body_graphs:
for attr, b_g in body_graphs.items():
logger.debug("start handling subgraph of %s's attribute %s", node.name, attr)
Expand All @@ -287,7 +288,7 @@ def tensorflow_onnx_mapping(g, ops_mapping, initialized_tables=None, is_tflite=F
b_g.topological_sort(b_g.get_nodes())
exceptions.extend(body_exceptions)
logger.debug("finish handling subgraph of %s's attribute %s", node.name, attr)

try:
func(g, node, **kwargs, initialized_tables=initialized_tables, dequantize=dequantize)
if not is_tflite:
Expand All @@ -302,7 +303,6 @@ def tensorflow_onnx_mapping(g, ops_mapping, initialized_tables=None, is_tflite=F
logger.error("Failed to convert node %r (fct=%r)\n%r",
node.name, func, summary, exc_info=1)
exceptions.append(ex)

return mapped_op, unmapped_op, exceptions


Expand Down Expand Up @@ -332,26 +332,96 @@ def transpose_inputs(ctx, inputs_as_nchw):
def transpose_outputs(ctx, outputs_as_nchw):
"""Insert a transpose from NHWC to NCHW on model output on users request."""
ops = []

# First pass: Find and handle edge cases in original nodes
edge_case_handled = set()

for node in ctx.get_nodes():
for output_name in node.output:
if output_name in outputs_as_nchw:
# Check if this output is used to create a model output
consumers = ctx.find_output_consumers(output_name)

# Look for edge case: output consumed by both model output node and other nodes
model_output_consumers = []
other_consumers = []

for consumer in consumers:
if consumer.output and any(out in outputs_as_nchw for out in consumer.output):
model_output_consumers.append(consumer)
else:
other_consumers.append(consumer)

# Edge case: original node output goes to both model output and other layers
if model_output_consumers and other_consumers:
# Get shape for validation
shape = ctx.get_shape(output_name)
if len(shape) != len(constants.NHWC_TO_NCHW):
continue

# Handle edge case: Use insert_node_on_output for proper structure
# Step 1: Create Identity node and insert it on the original output
identity_name = utils.make_name(node.name + "_identity")
identity = ctx.make_node("Identity", [output_name],
outputs=[identity_name + ":0"], name=identity_name)

# Copy shape information
ctx.copy_shape(output_name, identity.output[0])
ctx.set_shape(identity.output[0], shape)

# Insert the identity on the original output - this will redirect ALL consumers
ctx.insert_node_on_output(identity, output_name)

# Step 2: Create Transpose node and connect it to Identity
transpose_name = utils.make_name(identity.name + "_transpose")
transpose = ctx.make_node("Transpose", [identity.output[0]],
outputs=[transpose_name + ":0"], name=transpose_name)
transpose.set_attr("perm", constants.NHWC_TO_NCHW)
ctx.copy_shape(identity.output[0], transpose.output[0])
ctx.set_shape(transpose.output[0], np.array(shape)[constants.NHWC_TO_NCHW])

# Step 3: Manually redirect ONLY the model output consumers to use transpose
for consumer in model_output_consumers:
ctx.replace_all_inputs(identity.output[0], transpose.output[0], ops=[consumer])

# Mark this output as handled
edge_case_handled.add(output_name)

ops.append(node)
ops.append(identity)
ops.append(transpose)
break # Only handle one edge case per node

# If no edge case was handled for this node, add it normally
if not any(out in edge_case_handled for out in node.output):
ops.append(node)

# Second pass: Handle normal cases (nodes that directly output to model outputs)
final_ops = []
for node in ops:
handled = False
for output_name in node.output:
if output_name in outputs_as_nchw and output_name not in edge_case_handled:
# Get shape for validation
shape = ctx.get_shape(output_name)
if len(shape) != len(constants.NHWC_TO_NCHW):
logger.warning("transpose_output for %s: shape must be rank 4, ignored" % output_name)
ops.append(node)
continue

# insert transpose
op_name = utils.make_name(node.name)
transpose = ctx.insert_new_node_on_output("Transpose", node.input[0], name=op_name)
transpose = ctx.insert_new_node_on_output("Transpose", output_name, name=op_name)
transpose.set_attr("perm", constants.NHWC_TO_NCHW)
ctx.copy_shape(node.output[0], transpose.output[0])
ctx.set_shape(transpose.output[0], np.array(shape)[constants.NHWC_TO_NCHW])
ctx.copy_shape(output_name, transpose.output[0])
ctx.set_shape(output_name, np.array(shape)[constants.NHWC_TO_NCHW])
ops.append(transpose)
ops.append(node)
continue
ops.append(node)
ctx.reset_nodes(ops)
final_ops.append(transpose)
final_ops.append(node)
handled = True
break

if not handled:
final_ops.append(node)

ctx.reset_nodes(final_ops)

def topological_sort(g, continue_on_error):
ops = g.get_nodes()
Expand Down Expand Up @@ -522,7 +592,7 @@ def process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, outputs_as_nchw,
initialized_tables, is_tflite=False, dequantize=False):

op_cnt, attr_cnt = g.dump_node_statistics(include_attrs=True, include_subgraphs=False)

if is_tflite:
tfl_rewriters = []
if dequantize:
Expand All @@ -531,13 +601,16 @@ def process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, outputs_as_nchw,
tfl_rewriters.append(rewrite_tfl_select_zero)
tfl_rewriters.append(rewrite_tfl_rfft)
run_rewriters(g, tfl_rewriters, continue_on_error)

tfl_ops_mapping = handler.tfl_op.create_tfl_to_tf_mapping()
_, _, exceptions = tensorflow_onnx_mapping(g, tfl_ops_mapping, is_tflite=True, dequantize=False)

if exceptions and not continue_on_error:
raise exceptions[0]

# create ops mapping for the desired opsets
ops_mapping = handler.tf_op.create_mapping(g.opset, g.extra_opset)


# apply custom ops on top of the assembled opset. We can either complement the opset
# or override existing ops with a custom op.
Expand Down
2 changes: 1 addition & 1 deletion tf2onnx/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@


version = '1.16.1'
git_version = '13bab8a91e17ccd87541b2f361ab60e8e38359d3'
git_version = 'None'