2
2
3
3
4
4
"""Loop Optimizer.
5
- some op in loop's body graph can be moved out to the loop
5
+ some op in loop's body graph can be moved out of the loop
6
6
"""
7
7
8
8
from tf2onnx .utils import make_name , make_sure
@@ -36,13 +36,15 @@ def _optimize_at_current_graph_level(self, g):
36
36
return g
37
37
38
38
@staticmethod
39
- def consumer_nodes_num (graph , node ):
39
+ def num_consumers (graph , node ):
40
40
make_sure (len (node .output ) == 1 , "only consider node with only one output" )
41
41
res = len (graph .find_output_consumers (node .output [0 ]))
42
+ # This is an optimizer so we cannot rely on outputs having Identity nodes
43
+ res += graph .outputs .count (node .output [0 ])
42
44
return res
43
45
44
46
def _try_move_transpose_out_of_body_graph (self , loop_node ):
45
- # output node of body graph can be loop-carried-dependent, if so it can't be move out of the body graph
47
+ # output node of body graph can be loop-carried-dependent, if so it can't be moved out of the body graph
46
48
# return True if moving some nodes successfully
47
49
# for now, we only consider moving transpose
48
50
body_graph = loop_node .get_body_graphs ()["body" ]
@@ -54,14 +56,14 @@ def _try_move_transpose_out_of_body_graph(self, loop_node):
54
56
# 1 delete node in body graph if possible
55
57
# only consider two case: trans is output, or transpose > identity > output
56
58
need_process = False
57
- if node .type == "Transpose" and self .consumer_nodes_num (body_graph , node ) < = 1 :
59
+ if node .type == "Transpose" and self .num_consumers (body_graph , node ) = = 1 :
58
60
trans = node
59
61
new_output = node .input [0 ]
60
62
body_graph .remove_node (node .name )
61
63
need_process = True
62
64
elif node .type == "Identity" and node .inputs [0 ].type == "Transpose" \
63
- and self .consumer_nodes_num (body_graph , node ) < = 1 \
64
- and self .consumer_nodes_num (body_graph , node .inputs [0 ]) < = 1 :
65
+ and self .num_consumers (body_graph , node ) = = 1 \
66
+ and self .num_consumers (body_graph , node .inputs [0 ]) = = 1 :
65
67
trans = node .inputs [0 ]
66
68
new_output = node .inputs [0 ].input [0 ]
67
69
body_graph .remove_node (node .inputs [0 ].name )
0 commit comments