Skip to content

Commit 47acb8a

Browse files
Fix loop optimizer when scan output is also loop-carried-dependent (#1701)
Signed-off-by: Tom Wildenhain <[email protected]> Co-authored-by: Guenther Schmuelling <[email protected]>
1 parent 90f3689 commit 47acb8a

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

tf2onnx/optimizer/loop_optimizer.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33

44
"""Loop Optimizer.
5-
some op in loop's body graph can be moved out to the loop
5+
some op in loop's body graph can be moved out of the loop
66
"""
77

88
from tf2onnx.utils import make_name, make_sure
@@ -36,13 +36,15 @@ def _optimize_at_current_graph_level(self, g):
3636
return g
3737

3838
@staticmethod
39-
def consumer_nodes_num(graph, node):
39+
def num_consumers(graph, node):
4040
make_sure(len(node.output) == 1, "only consider node with only one output")
4141
res = len(graph.find_output_consumers(node.output[0]))
42+
# This is an optimizer so we cannot rely on outputs having Identity nodes
43+
res += graph.outputs.count(node.output[0])
4244
return res
4345

4446
def _try_move_transpose_out_of_body_graph(self, loop_node):
45-
# output node of body graph can be loop-carried-dependent, if so it can't be move out of the body graph
47+
# output node of body graph can be loop-carried-dependent, if so it can't be moved out of the body graph
4648
# return True if moving some nodes successfully
4749
# for now, we only consider moving transpose
4850
body_graph = loop_node.get_body_graphs()["body"]
@@ -54,14 +56,14 @@ def _try_move_transpose_out_of_body_graph(self, loop_node):
5456
# 1 delete node in body graph if possible
5557
# only consider two case: trans is output, or transpose > identity > output
5658
need_process = False
57-
if node.type == "Transpose" and self.consumer_nodes_num(body_graph, node) <= 1:
59+
if node.type == "Transpose" and self.num_consumers(body_graph, node) == 1:
5860
trans = node
5961
new_output = node.input[0]
6062
body_graph.remove_node(node.name)
6163
need_process = True
6264
elif node.type == "Identity" and node.inputs[0].type == "Transpose" \
63-
and self.consumer_nodes_num(body_graph, node) <= 1\
64-
and self.consumer_nodes_num(body_graph, node.inputs[0]) <= 1:
65+
and self.num_consumers(body_graph, node) == 1\
66+
and self.num_consumers(body_graph, node.inputs[0]) == 1:
6567
trans = node.inputs[0]
6668
new_output = node.inputs[0].input[0]
6769
body_graph.remove_node(node.inputs[0].name)

0 commit comments

Comments
 (0)