Skip to content

Commit c68e2c5

Browse files
committed
add loop_optimizer
1 parent da2d7a6 commit c68e2c5

File tree

2 files changed

+91
-0
lines changed

2 files changed

+91
-0
lines changed

tf2onnx/optimizer/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@
1313
from .identity_optimizer import IdentityOptimizer
1414
from .merge_duplicated_nodes_optimizer import MergeDuplicatedNodesOptimizer
1515
from .transpose_optimizer import TransposeOptimizer
16+
from .loop_optimizer import LoopOptimizer
1617
from .. import logging
1718

1819
# the optimizer sequence needs to be considered carefully
1920
_optimizers = OrderedDict([
2021
("optimize_transpose", TransposeOptimizer),
2122
("fold_constants", ConstFoldOptimizer),
23+
("loop_optimizer", LoopOptimizer),
2224
# merge_duplication should be used after optimize_transpose
2325
# because optimize_transpose may leave some transpose nodes that can be merged
2426
("merge_duplication", MergeDuplicatedNodesOptimizer),

tf2onnx/optimizer/loop_optimizer.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
4+
"""Loop Optimizer.
5+
Some ops in a loop's body graph can be moved out of the loop.
6+
"""
7+
8+
from tf2onnx.utils import make_name, make_sure
9+
from .optimizer_base import GraphOptimizerBase
10+
11+
12+
# pylint: disable=logging-not-lazy,unused-argument,missing-docstring,unused-variable,arguments-differ
13+
14+
15+
class LoopOptimizer(GraphOptimizerBase):
    """Loop Optimizer.

    Moves a Transpose that produces a scan output inside a Loop's body graph out of the body,
    replacing it with a single Transpose on the corresponding Loop output in the parent graph.
    """

    # a lot of terms used here come from loop's onnx spec
    # https://github.com/onnx/onnx/blob/master/docs/Operators.md#Loop
    def __init__(self):  # pylint: disable=useless-super-delegation
        super(LoopOptimizer, self).__init__()

    def _optimize(self, graph):
        """Run the per-graph optimization through the inherited ``_apply_optimization``."""
        return self._apply_optimization(graph, self._optimize_at_current_graph_level)

    def _optimize_at_current_graph_level(self, g):
        """Rewrite every "Loop" node of ``g`` until a full pass makes no change, then return ``g``."""
        has_update = True
        while has_update:
            has_update = False
            # the Loop nodes are collected again on every pass
            loop_nodes = [n for n in g.get_nodes() if n.type == "Loop"]
            for loop_node in loop_nodes:
                # every Loop node is visited, even after one of them reported a change
                if self._try_move_transpose_out_of_body_graph(loop_node):
                    has_update = True
        return g

    @staticmethod
    def consumer_nodes_num(graph, node):
        """Return the number of consumers ``graph.find_output_consumers`` reports for ``node``'s output.

        ``make_sure`` is used to check that ``node`` has exactly one output.
        """
        make_sure(len(node.output) == 1, "only consider node with only one output")
        return len(graph.find_output_consumers(node.output[0]))

    def _try_move_transpose_out_of_body_graph(self, loop_node):
        """Move Transpose nodes that produce scan outputs from the body graph to the parent graph.

        Only two patterns are considered: the Transpose output is itself a scan output, or
        Transpose > Identity > scan output. For each match the nodes are removed from the body
        graph, the body output is rewired to the Transpose's input, and a new Transpose (with the
        permutation shifted by one axis) is inserted on the matching Loop output in the parent graph.

        A candidate whose Transpose has no ``perm`` attribute is skipped without touching either
        graph, because the shifted permutation cannot be derived from the attribute alone.

        Returns:
            True if at least one Transpose was moved, False otherwise.
        """
        # output node of body graph can be loop-carried-dependent, if so it can't be moved out of the body graph
        # for now, we only consider moving transpose
        body_graph = loop_node.get_body_graphs()["body"]
        parent_graph = loop_node.graph
        scan_names_in_body, scan_names_in_parent = self._scan_outputs(loop_node)
        # resolve all producers up front, before any node is removed from the body graph
        scan_nodes = [body_graph.get_node_by_output(name) for name in scan_names_in_body]
        graph_is_changed = False
        for node, name_in_parent in zip(scan_nodes, scan_names_in_parent):
            if node is None:
                # no producer node was found for this scan output, nothing to move
                continue

            # 1 find the transpose (and optional identity) to delete from the body graph
            # NOTE(review): "<= 1" also accepts a Transpose that is a scan output and still has one
            # consumer node; confirm whether find_output_consumers counts the graph output itself.
            if node.type == "Transpose" and self.consumer_nodes_num(body_graph, node) <= 1:
                trans = node
                nodes_to_remove = [node]
            elif node.type == "Identity" and node.inputs[0].type == "Transpose" \
                    and self.consumer_nodes_num(body_graph, node) <= 1 \
                    and self.consumer_nodes_num(body_graph, node.inputs[0]) <= 1:
                trans = node.inputs[0]
                nodes_to_remove = [trans, node]
            else:
                continue

            # read everything needed from the transpose BEFORE mutating the graph, so that an
            # unusable node never leaves the body graph half-rewritten
            perm_attr = trans.get_attr("perm")
            if perm_attr is None:
                # "perm" is optional for Transpose; without it the new perm can't be computed here
                continue
            ori_perm = list(perm_attr.ints)
            new_output = trans.input[0]

            for stale in nodes_to_remove:
                body_graph.remove_node(stale.name)

            # 2 correct body graph's output
            body_outputs = body_graph.outputs
            body_outputs[body_outputs.index(node.output[0])] = new_output

            # 3 insert new node in parent graph
            # body output's rank is m while the rank of loop's output is m+1, so axis 0 is kept in
            # place and every original axis is shifted by one
            new_perm = [0] + [i + 1 for i in ori_perm]
            name = make_name("trans_moved_from_loop_body")
            parent_graph.insert_new_node_on_output("Transpose", name_in_parent, name, perm=new_perm)
            graph_is_changed = True

        return graph_is_changed

    @classmethod
    def _scan_outputs(cls, loop):
        """Return the scan output names of ``loop`` as ``(names_in_body_graph, names_in_parent_graph)``.

        The two sequences are sliced so that they line up element by element.
        """
        # loop has 2+N inputs; loop has N+K outputs;
        # loop's body graph has 1+N+K outputs
        loop_carried = len(loop.input) - 2
        body_graph = loop.get_body_graphs()["body"]
        return body_graph.outputs[loop_carried + 1:], loop.output[loop_carried:]

0 commit comments

Comments
 (0)