Commit 93648c4

Merge pull request #584 from zhijxu-MS/add_loop_optimizer
enhance optimizer
2 parents b23d3cb + 2b477f9

8 files changed: 185 additions, 8 deletions
tests/test_optimizers.py

Lines changed: 64 additions & 0 deletions
@@ -50,6 +50,21 @@ def run_and_compare(self, output_names_with_port, onnx_feed_dict, origin_proto,
             self.assertEqual(expected_val.shape, actual_val.shape)
 
         return new_proto
+
+    @staticmethod
+    def _make_onnx_const(np_val, output_name):
+        node = helper.make_node(
+            'Constant',
+            inputs=[],
+            outputs=[output_name],
+            value=helper.make_tensor(
+                name=output_name,
+                data_type=utils.map_numpy_to_onnx_dtype(np_val.dtype),
+                dims=np_val.shape,
+                vals=np_val.flatten().astype(np_val.dtype),
+            ),
+        )
+        return node
     # Transpose Optimizer Tests Start
 
     def run_transpose_compare(self, output_names_with_port, onnx_feed_dict, origin_proto,
@@ -304,6 +319,55 @@ def test_transpose_with_squeeze4(self):
         self.run_transpose_compare(["Z"], {"X": np.random.randn(3, 1, 1, 5).astype(np.float32)},
                                    model_proto, remaining_transpose_num=0)
 
+    def test_transpose_with_loop(self):
+        def _define_loop_graph(external_inputs):
+            # external_inputs: external nodes that will be used by this graph
+            # graph without loop-carried computation:
+            # for(...) {a = external_inputs[i]; b = trans(a); c = squeeze(b)}, c is the scan output
+            node1 = helper.make_node("Gather", [external_inputs[0], "loop_iter_num"], ["Y0"])
+            node2 = helper.make_node("Transpose", ["Y0"], ["Z0"], perm=[0, 2, 3, 1])
+            # graph output
+            node3 = helper.make_node("Squeeze", ["Z0"], ["scan_output"], axes=[0])
+            node4 = helper.make_node("Identity", ["loop_condition"], ["loop_cond_output"])
+            node5 = helper.make_node("Identity", ["loop_condition"], ["loop_carried_output"])
+
+            graph = helper.make_graph(
+                [node1, node2, node3, node4, node5],
+                "loop_subgraph",
+                [helper.make_tensor_value_info("loop_iter_num", TensorProto.INT64, (1,)),  # iteration_num
+                 helper.make_tensor_value_info("loop_condition", TensorProto.BOOL, ()),  # condition
+                 helper.make_tensor_value_info("loop_carried", TensorProto.BOOL, ())  # loop_carried
+                 ],
+                [helper.make_tensor_value_info("loop_cond_output", TensorProto.BOOL, ()),
+                 helper.make_tensor_value_info("loop_carried_output", TensorProto.BOOL, ()),
+                 helper.make_tensor_value_info("scan_output", TensorProto.FLOAT, ["unknown"] * 3)
+                 ],
+            )
+            return graph
+
+        def _make_loop(external_inputs, outputs):
+            trip_cnt = self._make_onnx_const(np.array(10, dtype=np.int64), "trip_cnt")
+            cond = self._make_onnx_const(np.array(True, dtype=np.bool), "cond")
+            sub_graph = _define_loop_graph(external_inputs)
+            loop_node = helper.make_node("Loop", ["trip_cnt", "cond", "cond"], outputs,
+                                         name="loop", body=sub_graph)
+            return trip_cnt, cond, loop_node
+
+        nodes = _make_loop(["array"], ["loop_carried", "scan_out"])
+        res = helper.make_node("Transpose", ["scan_out"], ["Y"], perm=[0, 3, 1, 2], name="trans")
+
+        graph = helper.make_graph(
+            [*nodes, res],
+            "transpose_with_loop",
+            [helper.make_tensor_value_info("array", TensorProto.FLOAT, ["unknown"] * 4)],
+            [helper.make_tensor_value_info("Y", TensorProto.FLOAT, ["unknown"] * 4)],
+        )
+
+        model_proto = helper.make_model(graph, producer_name="onnx-tests")
+        self.run_transpose_compare(["Y"], {"array": np.random.randn(10, 3, 4, 5).astype(np.float32)},
+                                   model_proto, remaining_transpose_num=0)
+
     def test_trans_output_as_graph_outputs(self):
         """
         If transpose's output is graph's output, don't optimize it.
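Why `remaining_transpose_num=0` is achievable in `test_transpose_with_loop` (my reading of how the optimizers interact, not something stated in the diff): the transpose optimizer can swap the body's Transpose with the Squeeze, the new loop optimizer can then hoist that Transpose out of the Loop (adding a leading axis to its perm because the Loop stacks scan outputs), and the hoisted Transpose composes with the outer Transpose(perm=[0, 3, 1, 2]) into an identity. A small numpy check of the underlying layout equivalence (illustrative only, not part of the test file):

```python
import numpy as np

array = np.random.randn(10, 3, 4, 5).astype(np.float32)

# What the body graph computes per iteration: gather one slice, transpose, squeeze.
per_iter = [np.transpose(array[i:i + 1], (0, 2, 3, 1)).squeeze(0) for i in range(10)]
scan_out = np.stack(per_iter)                      # Loop's stacked scan output, shape (10, 4, 5, 3)

# Equivalent "hoisted" form: stack the raw slices, then transpose once outside the
# loop with the lifted perm [0] + [i + 1 for i in (1, 2, 0)] == (0, 2, 3, 1).
hoisted = np.transpose(array, (0, 2, 3, 1))
assert np.array_equal(scan_out, hoisted)

# The outer Transpose(perm=[0, 3, 1, 2]) in the test then undoes it completely,
# which is why no Transpose nodes are expected to remain after optimization.
assert np.array_equal(np.transpose(scan_out, (0, 3, 1, 2)), array)
```

This rewrite needs more than one trip through the optimizer sequence, which is what the driver change in tf2onnx/optimizer/__init__.py below provides.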

tf2onnx/optimizer/__init__.py

Lines changed: 16 additions & 8 deletions
@@ -13,12 +13,14 @@
 from .identity_optimizer import IdentityOptimizer
 from .merge_duplicated_nodes_optimizer import MergeDuplicatedNodesOptimizer
 from .transpose_optimizer import TransposeOptimizer
+from .loop_optimizer import LoopOptimizer
 from .. import logging
 
 # the optimizer sequence needs to be considered carefully
 _optimizers = OrderedDict([
     ("optimize_transpose", TransposeOptimizer),
     ("fold_constants", ConstFoldOptimizer),
+    ("loop_optimizer", LoopOptimizer),
     # merge_duplication should be used after optimize_transpose,
     # because optimize_transpose may leave some trans nodes that can be merged
     ("merge_duplication", MergeDuplicatedNodesOptimizer),
@@ -37,14 +39,20 @@ def optimize_graph(graph):
 
     before = graph.dump_node_statistics()
     opts = _get_optimizers()
-    for name, factory in opts.items():
-        try:
-            logger.verbose("Apply %s", name)
-            current = copy.deepcopy(graph)
-            graph = factory().optimize(current)
-        except Exception:  # pylint: disable=broad-except
-            # if the current optimizer fails, continue with the other optimizers
-            logger.warning("Failed to apply %s", name, exc_info=1)
+    continue_flag = True
+    while continue_flag:
+        continue_flag = False
+        for name, factory in opts.items():
+            try:
+                logger.verbose("Apply %s", name)
+                current = copy.deepcopy(graph)
+                opt = factory()
+                graph = opt.optimize(current)
+                continue_flag = continue_flag or opt.graph_been_opt
+
+            except Exception:  # pylint: disable=broad-except
+                # if the current optimizer fails, continue with the other optimizers
+                logger.warning("Failed to apply %s", name, exc_info=1)
 
     after = graph.dump_node_statistics()
     diff = copy.deepcopy(after)
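The rewritten driver above re-runs the whole optimizer sequence until a complete round reports no change, using the new graph_been_opt flag as the signal. A minimal standalone sketch of that fixpoint pattern (the Optimizer class and toy rules below are illustrative stand-ins, not tf2onnx classes):

```python
# Minimal sketch of the fixpoint driver pattern; illustrative names only.
class Optimizer:
    def __init__(self, rule):
        self.rule = rule              # a function: graph -> (graph, changed)
        self.graph_been_opt = False   # True when the last pass changed the graph

    def optimize(self, graph):
        graph, self.graph_been_opt = self.rule(graph)
        return graph


def run_to_fixpoint(graph, optimizers):
    changed = True
    while changed:                    # keep cycling until a full round changes nothing
        changed = False
        for opt in optimizers:
            graph = opt.optimize(graph)
            changed = changed or opt.graph_been_opt
    return graph


# Toy rules on a "graph" that is just a list of tokens: drop no-ops, cancel a
# trailing pair of transposes. Each rule enables the other, so a single pass
# would not reach the final form.
drop_noops = Optimizer(lambda g: ([t for t in g if t != "noop"], "noop" in g))
cancel_double_t = Optimizer(lambda g: (g[:-2], True) if g[-2:] == ["T", "T"] else (g, False))
print(run_to_fixpoint(["T", "noop", "T"], [drop_noops, cancel_double_t]))  # -> []
```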

tf2onnx/optimizer/const_fold_optimizer.py

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@ def _optimize_at_current_graph_level(self, graph):
                 continue
             if self._fold_node(op, graph):
                 graph_changed = True
+                self.graph_been_opt = True
         return graph
 
     @staticmethod

tf2onnx/optimizer/identity_optimizer.py

Lines changed: 2 additions & 0 deletions
@@ -39,6 +39,8 @@ def _optimize_at_current_graph_level(self, g):
             else:
                 ret = self._handle_non_graph_output_identity(g, n)
             has_update = ret
+            if ret:
+                self.graph_been_opt = True
         return g
 
     @staticmethod

tf2onnx/optimizer/loop_optimizer.py

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT license.
+
+"""Loop Optimizer.
+some ops in a loop's body graph can be moved out of the loop
+"""
+
+from tf2onnx.utils import make_name, make_sure
+from .optimizer_base import GraphOptimizerBase
+
+
+# pylint: disable=logging-not-lazy,unused-argument,missing-docstring,unused-variable,arguments-differ
+
+
+class LoopOptimizer(GraphOptimizerBase):
+    """Loop Optimizer."""
+
+    # a lot of terms used here come from Loop's ONNX spec
+    # https://github.com/onnx/onnx/blob/master/docs/Operators.md#Loop
+    def __init__(self):  # pylint: disable=useless-super-delegation
+        super(LoopOptimizer, self).__init__()
+
+    def _optimize(self, graph):
+        return self._apply_optimization(graph, self._optimize_at_current_graph_level)
+
+    def _optimize_at_current_graph_level(self, g):
+        has_update = True
+        while has_update:
+            has_update = False
+            nodes = [n for n in g.get_nodes() if n.type == "Loop"]
+            for n in nodes:
+                has_update_tmp = self._try_move_transpose_out_of_body_graph(n)
+                if has_update_tmp:
+                    has_update = True
+                    self.graph_been_opt = True
+        return g
+
+    @staticmethod
+    def consumer_nodes_num(graph, node):
+        make_sure(len(node.output) == 1, "only consider nodes with a single output")
+        res = len(graph.find_output_consumers(node.output[0]))
+        return res
+
+    def _try_move_transpose_out_of_body_graph(self, loop_node):
+        # an output node of the body graph can be loop-carried-dependent; if so, it can't be moved out of the body graph
+        # return True if some nodes are moved successfully
+        # for now, we only consider moving Transpose
+        body_graph = loop_node.get_body_graphs()["body"]
+        parent_graph = loop_node.graph
+        scan_nodes_name_in_body, scan_node_in_parent = self._scan_outputs(loop_node)
+        scan_nodes = [body_graph.get_node_by_output(name) for name in scan_nodes_name_in_body]
+        graph_is_changed = False
+        for node, name_in_parent in zip(scan_nodes, scan_node_in_parent):
+            # 1 delete the node in the body graph if possible
+            # only consider two cases: the output is a Transpose, or Transpose > Identity > output
+            need_process = False
+            if node.type == "Transpose" and self.consumer_nodes_num(body_graph, node) <= 1:
+                trans = node
+                new_output = node.input[0]
+                body_graph.remove_node(node.name)
+                need_process = True
+            elif node.type == "Identity" and node.inputs[0].type == "Transpose" \
+                    and self.consumer_nodes_num(body_graph, node) <= 1 \
+                    and self.consumer_nodes_num(body_graph, node.inputs[0]) <= 1:
+                trans = node.inputs[0]
+                new_output = node.inputs[0].input[0]
+                body_graph.remove_node(node.inputs[0].name)
+                body_graph.remove_node(node.name)
+                need_process = True
+
+            if need_process:
+                # 2 correct the body graph's output
+                body_outputs = body_graph.outputs
+                body_outputs[body_outputs.index(node.output[0])] = new_output
+                # 3 insert a new node in the parent graph
+                ori_perm = list(trans.get_attr("perm").ints)
+                new_perm = [0] + [i + 1 for i in ori_perm]  # the body output's rank is m, while the loop output's rank is m+1
+                name = make_name("trans_moved_from_loop_body")
+                _ = parent_graph.insert_new_node_on_output("Transpose", name_in_parent, name, perm=new_perm)
+                graph_is_changed = True
+
+        return graph_is_changed
+
+    @classmethod
+    def _scan_outputs(cls, loop):
+        # a Loop has 2+N inputs and N+K outputs;
+        # its body graph has 1+N+K outputs
+        loop_carried = len(loop.input) - 2
+        body_graph = loop.get_body_graphs()["body"]
+        return body_graph.outputs[loop_carried + 1:], loop.output[loop_carried:]
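The perm adjustment `new_perm = [0] + [i + 1 for i in ori_perm]` accounts for the extra leading axis the Loop adds when it stacks each iteration's scan output. A quick numpy check of that rule (illustrative shapes and perm, not taken from the code above):

```python
import numpy as np

ori_perm = [2, 0, 1]                          # perm of the Transpose inside the body (rank m = 3)
new_perm = [0] + [i + 1 for i in ori_perm]    # perm after hoisting it out (rank m + 1 = 4)

slices = [np.random.randn(3, 4, 5) for _ in range(7)]   # one body output per iteration

# Transposing inside the body and letting the Loop stack the results ...
inside = np.stack([s.transpose(ori_perm) for s in slices])
# ... equals stacking first and transposing once outside the loop.
outside = np.stack(slices).transpose(new_perm)
assert np.array_equal(inside, outside)
```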

tf2onnx/optimizer/merge_duplicated_nodes_optimizer.py

Lines changed: 2 additions & 0 deletions
@@ -32,6 +32,8 @@ def _optimize_at_current_graph_level(self, graph):
         while self._graph_can_be_optimized:
             self._graph_can_be_optimized = False
             self._merge_duplicated_nodes(graph)
+            if self._graph_can_be_optimized:
+                self.graph_been_opt = True
         return graph
 
     def _merge_duplicated_nodes(self, graph):

tf2onnx/optimizer/optimizer_base.py

Lines changed: 9 additions & 0 deletions
@@ -16,6 +16,7 @@ class GraphOptimizerBase(object):
 
     def __init__(self):
        self._logger = logging.getLogger('.'.join(__name__.split('.')[:-1] + [self.__class__.__name__]))
+        self._graph_been_opt = False
 
     @property
     def logger(self):
@@ -25,6 +26,14 @@ def logger(self):
     def is_debug_mode(self):
         return utils.is_debug_mode()
 
+    @property
+    def graph_been_opt(self):
+        return self._graph_been_opt
+
+    @graph_been_opt.setter
+    def graph_been_opt(self, value):
+        self._graph_been_opt = value
+
     def optimize(self, graph):
         """ Optimize graph, return optimized graph. """
         before = graph.dump_node_statistics()
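The one-line additions in the other optimizers (constant folding, identity, duplicate merging, transpose) all follow the contract this property introduces: flip graph_been_opt to True only when a pass actually rewrote something, so the driver in __init__.py knows whether another round is worthwhile. A self-contained sketch of that contract (illustrative stand-in classes, not the real tf2onnx ones):

```python
# Sketch of the reporting contract a concrete optimizer follows; the classes
# here are stand-ins, the real base class lives in optimizer_base.py.
class GraphOptimizerBaseSketch:
    def __init__(self):
        self._graph_been_opt = False

    @property
    def graph_been_opt(self):
        return self._graph_been_opt

    @graph_been_opt.setter
    def graph_been_opt(self, value):
        self._graph_been_opt = value


class DropNoopsSketch(GraphOptimizerBaseSketch):
    def optimize(self, graph):
        cleaned = [n for n in graph if n != "noop"]
        if len(cleaned) != len(graph):
            self.graph_been_opt = True   # only report a change when a rewrite happened
        return cleaned


opt = DropNoopsSketch()
print(opt.optimize(["add", "noop", "mul"]), opt.graph_been_opt)  # ['add', 'mul'] True
```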

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 1 addition & 0 deletions
@@ -147,6 +147,7 @@ def _optimize_at_current_graph_level(self, graph):
             if is_nhwc_transpose(n):
                 if self._handle_nhwc_tranpose(n):
                     no_action = False
+                    self.graph_been_opt = True
             iteration_cnt += 1
             # need to break, because the handler may change the node set, making n a stale object
             # referencing already deleted elements
