
Commit 444923f

add graph optimizer - merge duplicated nodes
1 parent eb31fc3 commit 444923f

9 files changed: +265 −64 lines

tests/run_pretrained_models.py

Lines changed: 1 addition & 1 deletion
@@ -270,7 +270,7 @@ def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=No
         onnx_graph = self.to_onnx(sess.graph, opset=opset, shape_override=shape_override,
                                   input_names=inputs.keys())
         model_proto = onnx_graph.make_model("converted from tf2onnx")
-        new_model_proto = GraphUtil.optimize_graph(onnx_graph, "test", debug=debug)
+        new_model_proto = GraphUtil.optimize_graph(onnx_graph, debug=debug).make_model("optimized")
         if new_model_proto:
             model_proto = new_model_proto
         else:

tests/test_optimizers.py

Lines changed: 49 additions & 2 deletions
@@ -27,7 +27,7 @@ def run_and_compare(self, output_names_with_port, onnx_feed_dict, origin_proto,

         origin_model_path = self.save_onnx_model(origin_proto, onnx_feed_dict, postfix="_origin")

-        new_proto = GraphUtil.optimize_graph_with_model_proto(origin_proto)
+        new_proto = GraphUtil.optimize_model_proto(origin_proto)

         self.assertTrue(new_proto, msg="model proto after optimizer should not be None")


@@ -287,7 +287,54 @@ def test_identity_in_subgraph_non_graph_output(self):
         self.run_identity_compare(["Z1"], {"X": np.random.randn(2, 3, 4, 5).astype(np.float32)},
                                   model_proto, remaining_identity_num=0)

-    # Tranpose Optimizer Tests End
+    # Identity Optimizer Tests End
+
+    # Merge Duplicated Nodes Optimizer Tests Start
+
+    def run_merge_duplicated_nodes_compare(self, output_names_with_port, onnx_feed_dict, origin_proto,
+                                           op_type=None, remaining_op_num=None, debug=False, rtol=1e-07):
+        self.run_and_compare(output_names_with_port, onnx_feed_dict, origin_proto, op_type=op_type,
+                             remaining_op_num=remaining_op_num, debug=debug, rtol=rtol)
+
+    def test_duplicated_duplicated_input(self):
+        # same input or not
+        node0 = helper.make_node('Add', inputs=["X", "X"], outputs=["value0"])
+        node1 = helper.make_node('Add', inputs=["X", "X"], outputs=["value1"])
+        node2 = helper.make_node('Add', inputs=["value1", "X"], outputs=["value2"])
+        node3 = helper.make_node("Mul", ["value0", "value2"], ["value3"])
+        node4 = helper.make_node("Mul", ["value1", "value3"], ["OUT"])
+
+        graph = helper.make_graph(
+            [node0, node1, node2, node3, node4],
+            "transpose-merge-test",
+            [helper.make_tensor_value_info("X", TensorProto.FLOAT, (5, 5))],
+            [helper.make_tensor_value_info("OUT", TensorProto.FLOAT, (5, 5))],
+        )
+
+        model_proto = helper.make_model(graph, producer_name="onnx-tests")
+        self.run_merge_duplicated_nodes_compare(["OUT"], {"X": np.random.randn(5, 5).astype(np.float32)}, model_proto,
+                                                op_type="Add", remaining_op_num=2)
+
+    def test_duplicated_duplicated_attributes(self):
+        # same attr or not
+        node0 = helper.make_node('ReduceSum', inputs=["X"], outputs=["value0"], axes=[0], keepdims=0)
+        node1 = helper.make_node('ReduceSum', inputs=["X"], outputs=["value1"], axes=[0], keepdims=0)
+        node2 = helper.make_node('ReduceSum', inputs=["X"], outputs=["value2"], axes=[1], keepdims=0)
+        node3 = helper.make_node('Add', inputs=["value0", "value1"], outputs=["value3"])
+        node4 = helper.make_node("Mul", ["value2", "value3"], ["OUT"])
+
+        graph = helper.make_graph(
+            [node0, node1, node2, node3, node4],
+            "transpose-merge-test",
+            [helper.make_tensor_value_info("X", TensorProto.FLOAT, (5, 5))],
+            [helper.make_tensor_value_info("OUT", TensorProto.FLOAT, (5,))],
+        )
+
+        model_proto = helper.make_model(graph, producer_name="onnx-tests")
+        self.run_merge_duplicated_nodes_compare(["OUT"], {"X": np.random.randn(5, 5).astype(np.float32)}, model_proto,
+                                                op_type="ReduceSum", remaining_op_num=2)
+    # Merge Duplicated Nodes Optimizer Tests End

 if __name__ == "__main__":
     unittest_main()

tf2onnx/convert.py

Lines changed: 2 additions & 3 deletions
@@ -117,10 +117,9 @@ def main():
                                          output_names=outputs,
                                          inputs_as_nchw=args.inputs_as_nchw)

-    model_proto = g.make_model("converted from {}".format(args.input))
+    model_proto = g.make_model("converted from {}".format(model_path))

-    new_model_proto = GraphUtil.optimize_graph(g, "converted from {}".format(model_path),
-                                               optimize=not args.continue_on_error)
+    new_model_proto = GraphUtil.optimize_model_proto(model_proto)
     if new_model_proto:
         model_proto = new_model_proto
     else:

tf2onnx/graph.py

Lines changed: 13 additions & 37 deletions
@@ -17,10 +17,10 @@
 import six
 import numpy as np

-from onnx import helper, numpy_helper, optimizer, shape_inference, OperatorSetIdProto, AttributeProto
+from onnx import helper, numpy_helper, shape_inference, OperatorSetIdProto, AttributeProto
 from tf2onnx import utils, __version__
 from tf2onnx.utils import port_name, find_opset
-from tf2onnx.optimizer import IdentityOptimizer, TransposeOptimizer
+from tf2onnx import optimizer
 from tf2onnx.schemas import get_schema

 logging.basicConfig(level=logging.INFO)


@@ -1042,30 +1042,11 @@ class GraphUtil(object):
     """Utilities for Graph manipulation."""

     @staticmethod
-    def optimize_graph(graph, doc_string, optimize=None, debug=False):
-        """Optimize the graph, for example: eliminating all useless Transpose/Identity pairs.
-
-        Returns:
-            model proto after optimization, if optimizer run successfully
-            or None, if exceptions happen
-        """
-        try:
-            opts = [TransposeOptimizer(graph, output_names=graph.outputs, debug=debug),
-                    IdentityOptimizer(graph, output_names=graph.outputs, debug=debug)
-                    ]
-            for opt in opts:
-                opt.optimize()
-            model_proto = graph.make_model(doc_string, optimize=optimize)
-            return model_proto
-        except Exception:
-            # degradation to non-optimized model proto
-            type_, value_, traceback_ = sys.exc_info()
-            ex_ext = traceback.format_exception(type_, value_, traceback_)
-            print("NON-CRITICAL error in optimizer: ", ex_ext)
-            return None
+    def optimize_graph(graph, debug=False):
+        return optimizer.optimize_graph(graph, debug)

     @staticmethod
-    def optimize_graph_with_model_proto(onnx_model_proto, debug=False):
+    def optimize_model_proto(onnx_model_proto, debug=False):
         """Optimize the model proto, for example: eliminating all useless Transpose pairs.

         Returns:

@@ -1074,16 +1055,10 @@ def optimize_graph_with_model_proto(onnx_model_proto, debug=False):
         """
         try:
             kwargs = GraphUtil.get_onnx_model_properties(onnx_model_proto)
-            g = GraphUtil.create_graph_from_onnx_model(onnx_model_proto)
-
-            opts = [TransposeOptimizer(g, output_names=g.outputs, debug=debug),
-                    IdentityOptimizer(g, output_names=g.outputs, debug=debug)
-                    ]
-            for opt in opts:
-                opt.optimize()
-
-            model_proto = g.make_model(onnx_model_proto.graph.doc_string,
-                                       graph_name=onnx_model_proto.graph.name, **kwargs)
+            graph = GraphUtil.create_graph_from_onnx_model(onnx_model_proto)
+            graph = GraphUtil.optimize_graph(graph, debug)
+            model_proto = graph.make_model(onnx_model_proto.graph.doc_string,
+                                           graph_name=onnx_model_proto.graph.name, **kwargs)

             if onnx_model_proto.metadata_props:
                 metadata_props = {p.key: p.value for p in onnx_model_proto.metadata_props}

@@ -1123,11 +1098,12 @@ def create_graph_from_onnx_model(onnx_model_proto):
         # apply shape inference on the model
         inferred_model = shape_inference.infer_shapes(onnx_model_proto)
         graph_proto = inferred_model.graph
-        main_graph = GraphUtil.create_graph_from_onnx_graph(graph_proto)
+        opset_version = onnx_model_proto.opset_import[0].version
+        main_graph = GraphUtil.create_graph_from_onnx_graph(graph_proto, opset_version)
        return main_graph

     @staticmethod
-    def create_graph_from_onnx_graph(graph_proto):
+    def create_graph_from_onnx_graph(graph_proto, opset_version=None):
         """Create Graph loading onnx graph proto."""
         output_shapes = {}
         output_dtypes = {}

@@ -1154,7 +1130,7 @@ def create_graph_from_onnx_graph(graph_proto):
         for n in graph_proto.output:
             output_names.append(n.name)

-        g = Graph(nodes_to_append, output_shapes, output_dtypes, None, None, None, output_names)
+        g = Graph(nodes_to_append, output_shapes, output_dtypes, None, opset_version, None, output_names)
         const_nodes = GraphUtil._parse_graph_initializer(g, graph_proto)
         GraphUtil._parse_graph_input(g, graph_proto, [n.name for n in const_nodes])
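
For reference, a minimal usage sketch of the renamed GraphUtil.optimize_model_proto API (not part of the commit): it assumes a local ONNX file named "model.onnx" and falls back to the original proto when the optimizer returns None.

import onnx
from tf2onnx.graph import GraphUtil

model_proto = onnx.load("model.onnx")           # hypothetical input path
new_model_proto = GraphUtil.optimize_model_proto(model_proto)
if new_model_proto:                             # the optimizer returns None on non-critical failures
    model_proto = new_model_proto
onnx.save(model_proto, "model.optimized.onnx")  # hypothetical output path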

tf2onnx/optimizer/__init__.py

Lines changed: 34 additions & 6 deletions
@@ -6,10 +6,38 @@
 from __future__ import print_function
 from __future__ import unicode_literals

-from .identity_optimizer import IdentityOptimizer
-from .transpose_optimizer import TransposeOptimizer
+import sys
+import traceback

-__all__ = [
-    "IdentityOptimizer",
-    "TransposeOptimizer",
-]
+from tf2onnx.optimizer.identity_optimizer import IdentityOptimizer
+from tf2onnx.optimizer.merge_duplicated_nodes_optimizer import MergeDuplicatedNodesOptimizer
+from tf2onnx.optimizer.transpose_optimizer import TransposeOptimizer
+
+# pylint: disable=missing-docstring, broad-except
+
+# the optimizer sequence needs to be chosen carefully
+_optimizers = {
+    "transpose_opt": TransposeOptimizer,
+    # merge_duplicated_nodes should run after transpose_opt,
+    # because transpose_opt may leave transpose nodes that can be merged
+    "merge_duplicated_nodes": MergeDuplicatedNodesOptimizer,
+    "identity_opt": IdentityOptimizer,
+}
+
+
+def optimize_graph(graph, debug=False):
+    try:
+        opts = _get_optimizers()
+        for opt in opts.values():
+            graph = opt(debug=debug).optimize(graph)
+        return graph
+    except Exception:
+        # degrade to the non-optimized graph: return None so the caller can fall back
+        type_, value_, traceback_ = sys.exc_info()
+        ex_ext = traceback.format_exception(type_, value_, traceback_)
+        print("NON-CRITICAL error in optimizer: ", ex_ext)
+        return None
+
+
+def _get_optimizers():
+    return _optimizers

tf2onnx/optimizer/identity_optimizer.py

Lines changed: 10 additions & 10 deletions
@@ -12,26 +12,25 @@
 log = logging.getLogger("tf2onnx.optimizer.identity_optimizer")


-# pylint: disable=logging-not-lazy,unused-argument,missing-docstring
-# FIXME:
-# pylint: disable=unused-variable
+# pylint: disable=logging-not-lazy,unused-argument,missing-docstring,unused-variable


 class IdentityOptimizer(object):
     """Identity Optimizer."""

-    def __init__(self, main_graph, output_names, debug=False):
-        self._g = main_graph
-        self._output_names = [name.split(":")[0] for name in output_names]
+    def __init__(self, debug=False):
         self._debug = debug
+        self._g = None

-    def optimize(self):
+    def optimize(self, graph):
+        self._g = graph
         previous_counter = self._g.dump_node_statistics()
         self._optimize_recursively(self._g)
         current_counter = self._g.dump_node_statistics()
         identity_cnt = current_counter["Identity"]
         current_counter.subtract(previous_counter)
         log.info(" %d identity op(s) left, ops diff after identity optimization: %s", identity_cnt, current_counter)
+        return self._g

     def _optimize_recursively(self, g):
         self._optimize(g)

@@ -64,13 +63,14 @@ def _optimize(self, g):

         self._g.topological_sort(self._g.get_nodes())

-
-    def _handle_non_graph_output_identity(self, graph, identity):
+    @staticmethod
+    def _handle_non_graph_output_identity(graph, identity):
         graph.replace_all_inputs(graph.get_nodes(), identity.output[0], identity.input[0])
         graph.remove_node(identity.name)
         return True

-    def _handle_graph_output_identity(self, graph, identity, graph_outputs):
+    @staticmethod
+    def _handle_graph_output_identity(graph, identity, graph_outputs):
         input_id = identity.input[0]
         input_node = identity.inputs[0]
tf2onnx/optimizer/merge_duplicated_nodes_optimizer.py

Lines changed: 91 additions & 0 deletions

@@ -0,0 +1,91 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT license.
+
+"""Merge Duplicated Nodes Optimizer.
+Removes duplicated nodes, except Identity nodes, which are handled by the identity optimizer.
+For example, if node a feeds nodes b and c, and b and c perform the same computation (e.g. an "Abs" op),
+then b and c can be merged into one node to avoid the duplicated computation.
+"""
+
+from collections import defaultdict, namedtuple
+
+from tf2onnx.optimizer.optimizer_base import GraphOptimizerBase
+
+
+# pylint: disable=logging-not-lazy,unused-argument,missing-docstring
+
+
+class MergeDuplicatedNodesOptimizer(GraphOptimizerBase):
+    """Remove duplicated nodes."""
+    _key_to_group_nodes = namedtuple("key", "type input")
+
+    def __init__(self, name="MergeDuplicatedNodesOptimizer", debug=False):
+        super(MergeDuplicatedNodesOptimizer, self).__init__(name=name)
+
+        self._graph_can_be_optimized = True
+
+    def _optimize(self, graph):
+        return self._apply_optimization(graph, self._optimize_at_current_graph_level)
+
+    def _optimize_at_current_graph_level(self, graph):
+        while self._graph_can_be_optimized:
+            self._graph_can_be_optimized = False
+            self._merge_duplicated_nodes(graph)
+        return graph
+
+    def _merge_duplicated_nodes(self, graph):
+        # "duplicated" means op_type, inputs and attributes are all the same;
+        # attributes are not hashable, so they are not part of the grouping key
+        nodes_groups = self._group_nodes_by_type_inputs(graph)
+        for _, nodes_group in nodes_groups.items():
+            if self._skip_node_type(nodes_group[0]):
+                continue
+            self._del_nodes_if_duplicated(nodes_group, graph)
+
+    def _group_nodes_by_type_inputs(self, graph):
+        res = defaultdict(list)
+        for node in graph.get_nodes():
+            res[self._key_to_group_nodes(node.type, tuple(node.input))].append(node)
+        return res
+
+    def _del_nodes_if_duplicated(self, nodes_group, graph):
+        # nodes in the same group already share op type and inputs;
+        # if their attributes are also the same, they are duplicated
+        while len(nodes_group) > 1:
+            unprocessed_node = []
+            nodes_to_process = [nodes_group[0]]
+            for node in nodes_group[1:]:
+                if node.attr == nodes_to_process[0].attr:
+                    nodes_to_process.append(node)
+                else:
+                    unprocessed_node.append(node)
+
+            self._merge_nodes_that_are_duplicated(nodes_to_process, graph)
+            nodes_group = unprocessed_node
+
+    def _merge_nodes_that_are_duplicated(self, nodes_to_process, graph):
+        # not all of a node's outputs may be used, so retain the node with the most outputs
+        nodes_to_process.sort(key=self._len_of_node_output, reverse=True)
+        node_to_retain = nodes_to_process[0]
+        for node_to_delete in nodes_to_process[1:]:
+            # a node can't be deleted if one of its outputs is a graph output
+            if set(node_to_delete.output).intersection(set(graph.outputs)):
+                continue
+            for old_input, new_input in zip(node_to_delete.output, node_to_retain.output):
+                graph.replace_all_inputs(graph.get_nodes(), old_input, new_input)
+            graph.remove_node(node_to_delete.name)
+            self._graph_can_be_optimized = True
+
+    @staticmethod
+    def _skip_node_type(node):
+        # Identity nodes are handled by the identity optimizer, so skip them
+        if node.type in ["Identity"]:
+            return True
+        if node.is_graph_input():
+            return True
+        return False
+
+    @staticmethod
+    def _len_of_node_output(node):
+        return len(node.output)
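
To make the docstring's example concrete, here is a runnable sketch (not part of the commit) that builds two nodes computing the same Abs over the same input and runs only this optimizer on them; it relies on the public optimize(graph) entry point that GraphOptimizerBase (not shown in this diff) provides to tf2onnx.optimizer.optimize_graph above:

from onnx import helper, TensorProto
from tf2onnx.graph import GraphUtil
from tf2onnx.optimizer import MergeDuplicatedNodesOptimizer

nodes = [
    helper.make_node("Abs", ["X"], ["b"]),  # duplicated computation:
    helper.make_node("Abs", ["X"], ["c"]),  # same op type, same input, same (empty) attributes
    helper.make_node("Add", ["b", "c"], ["OUT"]),
]
graph_proto = helper.make_graph(
    nodes, "merge-duplicates-demo",
    [helper.make_tensor_value_info("X", TensorProto.FLOAT, (3,))],
    [helper.make_tensor_value_info("OUT", TensorProto.FLOAT, (3,))],
)
model_proto = helper.make_model(graph_proto, producer_name="merge-duplicates-demo")

graph = GraphUtil.create_graph_from_onnx_model(model_proto)
graph = MergeDuplicatedNodesOptimizer().optimize(graph)
print(graph.dump_node_statistics())  # a single Abs is expected to remain, feeding both Add inputs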
