Commit dc792f7

Merge pull request #396 from zhijxu-MS/push_branch
add graph optimizer "merge duplicated nodes"
2 parents eb31fc3 + 469435b commit dc792f7

10 files changed: +343 -73 lines

tests/backend_test_base.py

Lines changed: 2 additions & 0 deletions

@@ -18,6 +18,7 @@
 from common import get_test_config
 from tf2onnx import utils
 from tf2onnx.tfonnx import process_tf_graph, tf_optimize
+from tf2onnx import optimizer


 # pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test
@@ -140,6 +141,7 @@ def run_test_case(self, feed_dict, input_names_with_port, output_names_with_port
         with tf.Session() as sess:
             g = process_tf_graph(sess.graph, opset=self.config.opset, output_names=output_names_with_port,
                                  target=self.config.target, **process_args)
+            g = optimizer.optimize_graph(g)
             actual = self._run_backend(g, output_names_with_port, onnx_feed_dict)

             for expected_val, actual_val in zip(expected, actual):

tests/run_pretrained_models.py

Lines changed: 2 additions & 2 deletions

@@ -30,7 +30,7 @@
 import tf2onnx
 from tf2onnx import loader
 from tf2onnx import utils
-from tf2onnx.graph import GraphUtil
+from tf2onnx import optimizer
 from tf2onnx.tfonnx import process_tf_graph

 # pylint: disable=broad-except,logging-not-lazy,unused-argument,unnecessary-lambda
@@ -270,7 +270,7 @@ def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=No
             onnx_graph = self.to_onnx(sess.graph, opset=opset, shape_override=shape_override,
                                       input_names=inputs.keys())
             model_proto = onnx_graph.make_model("converted from tf2onnx")
-            new_model_proto = GraphUtil.optimize_graph(onnx_graph, "test", debug=debug)
+            new_model_proto = optimizer.optimize_graph(onnx_graph, debug=debug).make_model("optimized")
             if new_model_proto:
                 model_proto = new_model_proto
             else:
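
Both test drivers now follow the same pattern: the optimizer works on a tf2onnx Graph and returns an optimized Graph, and the caller builds the model proto from it afterwards. A minimal hedged sketch of that flow; the tiny TensorFlow graph, opset and output name below are placeholders, not values from this commit:

import tensorflow as tf
from tf2onnx import optimizer
from tf2onnx.tfonnx import process_tf_graph

with tf.Session() as sess:
    # Placeholder graph just to have something to convert.
    x = tf.placeholder(tf.float32, [2, 3], name="input")
    _ = tf.identity(x + x, name="output")
    g = process_tf_graph(sess.graph, opset=9, output_names=["output:0"])
    model_proto = g.make_model("converted from tf2onnx")
    optimized_g = optimizer.optimize_graph(g)      # returns None if any optimizer fails
    if optimized_g is not None:
        model_proto = optimized_g.make_model("optimized")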

tests/test_optimizers.py

Lines changed: 110 additions & 2 deletions

@@ -27,7 +27,7 @@ def run_and_compare(self, output_names_with_port, onnx_feed_dict, origin_proto,

         origin_model_path = self.save_onnx_model(origin_proto, onnx_feed_dict, postfix="_origin")

-        new_proto = GraphUtil.optimize_graph_with_model_proto(origin_proto)
+        new_proto = GraphUtil.optimize_model_proto(origin_proto)

         self.assertTrue(new_proto, msg="model proto after optimizer should not be None")

@@ -287,7 +287,115 @@ def test_identity_in_subgraph_non_graph_output(self):
         self.run_identity_compare(["Z1"], {"X": np.random.randn(2, 3, 4, 5).astype(np.float32)},
                                   model_proto, remaining_identity_num=0)

-    # Tranpose Optimizer Tests End
+    # Identity Optimizer Tests End
+
+    # Merge Duplicated Nodes Optimizer Tests Start
+
+    def run_merge_duplicated_nodes_compare(self, output_names_with_port, onnx_feed_dict, origin_proto,
+                                           op_type=None, remaining_op_num=None, debug=False, rtol=1e-07):
+        self.run_and_compare(output_names_with_port, onnx_feed_dict, origin_proto, op_type=op_type,
+                             remaining_op_num=remaining_op_num, debug=debug, rtol=rtol)
+
+    def test_duplicated_duplicated_input(self):
+        # same input or not
+        node0 = helper.make_node('Add', inputs=["X", "X"], outputs=["value0"])
+        node1 = helper.make_node('Add', inputs=["X", "X"], outputs=["value1"])
+        node2 = helper.make_node('Add', inputs=["value1", "X"], outputs=["value2"])
+        node3 = helper.make_node("Mul", ["value0", "value2"], ["value3"])
+        node4 = helper.make_node("Mul", ["value1", "value3"], ["OUT"])
+
+        graph = helper.make_graph(
+            [node0, node1, node2, node3, node4],
+            "test_duplicated_duplicated_input",
+            [helper.make_tensor_value_info("X", TensorProto.FLOAT, (5, 5))],
+            [helper.make_tensor_value_info("OUT", TensorProto.FLOAT, (5, 5))],
+        )
+
+        model_proto = helper.make_model(graph, producer_name="onnx-tests")
+        self.run_merge_duplicated_nodes_compare(["OUT"], {"X": np.random.randn(5, 5).astype(np.float32)}, model_proto,
+                                                op_type="Add", remaining_op_num=2)
+
+    def test_duplicated_duplicated_attributes(self):
+        # same attr or not
+        node0 = helper.make_node('ReduceSum', inputs=["X"], outputs=["value0"], axes=[0], keepdims=0)
+        node1 = helper.make_node('ReduceSum', inputs=["X"], outputs=["value1"], axes=[0], keepdims=0)
+        node2 = helper.make_node('ReduceSum', inputs=["X"], outputs=["value2"], axes=[1], keepdims=0)
+        node3 = helper.make_node('Add', inputs=["value0", "value1"], outputs=["value3"])
+        node4 = helper.make_node("Mul", ["value2", "value3"], ["OUT"])
+
+        graph = helper.make_graph(
+            [node0, node1, node2, node3, node4],
+            "test_duplicated_duplicated_attributes",
+            [helper.make_tensor_value_info("X", TensorProto.FLOAT, (5, 5))],
+            [helper.make_tensor_value_info("OUT", TensorProto.FLOAT, (5,))],
+        )
+
+        model_proto = helper.make_model(graph, producer_name="onnx-tests")
+        self.run_merge_duplicated_nodes_compare(["OUT"], {"X": np.random.randn(5, 5).astype(np.float32)}, model_proto,
+                                                op_type="ReduceSum", remaining_op_num=2)
+
+    def test_duplicated_node_is_graph_output(self):
+        node0 = helper.make_node('Add', inputs=["X", "X"], outputs=["value0"])
+        node1 = helper.make_node('Add', inputs=["X", "X"], outputs=["value1"])
+        node2 = helper.make_node('Add', inputs=["value1", "X"], outputs=["value2"])
+
+        graph = helper.make_graph(
+            [node0, node1, node2],
+            "test_duplicated_node_is_graph_output",
+            [helper.make_tensor_value_info("X", TensorProto.FLOAT, (5, 5))],
+            [helper.make_tensor_value_info("value1", TensorProto.FLOAT, (5, 5)),
+             helper.make_tensor_value_info("value2", TensorProto.FLOAT, (5, 5))],
+        )
+
+        model_proto = helper.make_model(graph, producer_name="onnx-tests")
+        self.run_merge_duplicated_nodes_compare(["value1", "value2"],
+                                                {"X": np.random.randn(5, 5).astype(np.float32)}, model_proto,
+                                                op_type="Add", remaining_op_num=2)
+
+    def test_duplicated_different_output_length(self):
+        node0 = helper.make_node('Dropout', inputs=["X"], outputs=["value0"])
+        node1 = helper.make_node('Dropout', inputs=["X"], outputs=["value1", "mask"])
+        node2 = helper.make_node('Dropout', inputs=["value1"], outputs=["value2"])
+
+        graph = helper.make_graph(
+            [node0, node1, node2],
+            "test_duplicated_different_output_length",
+            [helper.make_tensor_value_info("X", TensorProto.FLOAT, (5,))],
+            [helper.make_tensor_value_info("value1", TensorProto.FLOAT, (5,)),
+             helper.make_tensor_value_info("mask", TensorProto.BOOL, (5,)),
+             helper.make_tensor_value_info("value2", TensorProto.FLOAT, (5,))],
+        )
+
+        model_proto = helper.make_model(graph, producer_name="onnx-tests")
+        self.run_merge_duplicated_nodes_compare(["value1", "mask", "value2"],
+                                                {"X": np.random.randn(5,).astype(np.float32)},
+                                                model_proto,
+                                                op_type="Dropout", remaining_op_num=2)
+
+    def test_duplicated_need_multiple_run(self):
+        node00 = helper.make_node('Log', inputs=["X"], outputs=["value00"])
+        node01 = helper.make_node('Log', inputs=["value00"], outputs=["value01"])
+        node02 = helper.make_node('Log', inputs=["value01"], outputs=["value02"])
+
+        node10 = helper.make_node('Log', inputs=["X"], outputs=["value10"])
+        node11 = helper.make_node('Log', inputs=["value10"], outputs=["value11"])
+        node12 = helper.make_node('Log', inputs=["value11"], outputs=["value12"])
+
+        res = helper.make_node('Add', inputs=["value02", "value12"], outputs=["res"])
+
+        graph = helper.make_graph(
+            [node00, node01, node02, node10, node11, node12, res],
+            "test_duplicated_node_is_graph_output",
+            [helper.make_tensor_value_info("X", TensorProto.FLOAT, (5,))],
+            [helper.make_tensor_value_info("res", TensorProto.FLOAT, (5,))],
+        )
+
+        model_proto = helper.make_model(graph, producer_name="onnx-tests")
+        self.run_merge_duplicated_nodes_compare(["res"], {"X": np.random.randn(5,).astype(np.float32)},
+                                                model_proto,
+                                                op_type="Log", remaining_op_num=3)
+    # Merge Duplicated Nodes Optimizer Tests End
+

 if __name__ == "__main__":
     unittest_main()
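
These tests pin down the behavior expected of the new MergeDuplicatedNodesOptimizer: nodes are merged only when op type, inputs, attributes and output arity all match, graph outputs survive, and chained duplicates eventually collapse. The optimizer's own source file is not part of this diff, so the following is only an illustrative sketch of the signature-based idea, written against plain ONNX NodeProto objects rather than tf2onnx's Graph, and it ignores details such as preserving graph outputs:

def merge_duplicated_nodes(nodes):
    """Collapse nodes that share op type, inputs, attributes and output count.

    `nodes` must be topologically sorted onnx NodeProto objects; returns the
    surviving nodes and a map from dropped output names to kept output names.
    """
    seen = {}     # signature -> surviving node
    rename = {}   # dropped output name -> kept output name
    kept = []
    for node in nodes:
        # Rewire inputs that referred to outputs of already-merged nodes, so
        # chains of duplicates collapse in a single topological pass.
        new_inputs = [rename.get(name, name) for name in node.input]
        del node.input[:]
        node.input.extend(new_inputs)
        sig = (node.op_type,
               tuple(node.input),
               tuple(sorted(a.SerializeToString() for a in node.attribute)),
               len(node.output))
        if sig in seen:
            for dropped, kept_name in zip(node.output, seen[sig].output):
                rename[dropped] = kept_name
        else:
            seen[sig] = node
            kept.append(node)
    return kept, rename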

tf2onnx/convert.py

Lines changed: 2 additions & 3 deletions

@@ -117,10 +117,9 @@ def main():
                             output_names=outputs,
                             inputs_as_nchw=args.inputs_as_nchw)

-    model_proto = g.make_model("converted from {}".format(args.input))
+    model_proto = g.make_model("converted from {}".format(model_path))

-    new_model_proto = GraphUtil.optimize_graph(g, "converted from {}".format(model_path),
-                                               optimize=not args.continue_on_error)
+    new_model_proto = GraphUtil.optimize_model_proto(model_proto)
     if new_model_proto:
         model_proto = new_model_proto
     else:
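
convert.py now builds the model proto first and asks GraphUtil.optimize_model_proto for an optimized copy, falling back to the unoptimized proto when the optimizer returns None. The same entry point can be applied to an ONNX file that already exists on disk; a small hedged sketch, with placeholder file names:

import onnx
from tf2onnx.graph import GraphUtil

model_proto = onnx.load("model.onnx")                     # placeholder path
optimized_proto = GraphUtil.optimize_model_proto(model_proto)
if optimized_proto is not None:                           # optimizer degrades to the original on failure
    onnx.save(optimized_proto, "model.opt.onnx")          # placeholder path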

tf2onnx/graph.py

Lines changed: 14 additions & 38 deletions

@@ -17,10 +17,10 @@
 import six
 import numpy as np

-from onnx import helper, numpy_helper, optimizer, shape_inference, OperatorSetIdProto, AttributeProto
+from onnx import helper, numpy_helper, shape_inference, OperatorSetIdProto, AttributeProto
 from tf2onnx import utils, __version__
 from tf2onnx.utils import port_name, find_opset
-from tf2onnx.optimizer import IdentityOptimizer, TransposeOptimizer
+from tf2onnx import optimizer
 from tf2onnx.schemas import get_schema

 logging.basicConfig(level=logging.INFO)
@@ -1042,30 +1042,11 @@ class GraphUtil(object):
     """Utilities for Graph manipulation."""

     @staticmethod
-    def optimize_graph(graph, doc_string, optimize=None, debug=False):
-        """Optimize the graph, for example: eliminating all useless Transpose/Identity pairs.
-
-        Returns:
-            model proto after optimization, if optimizer run successfully
-            or None, if exceptions happen
-        """
-        try:
-            opts = [TransposeOptimizer(graph, output_names=graph.outputs, debug=debug),
-                    IdentityOptimizer(graph, output_names=graph.outputs, debug=debug)
-                    ]
-            for opt in opts:
-                opt.optimize()
-            model_proto = graph.make_model(doc_string, optimize=optimize)
-            return model_proto
-        except Exception:
-            # degradation to non-optimized model proto
-            type_, value_, traceback_ = sys.exc_info()
-            ex_ext = traceback.format_exception(type_, value_, traceback_)
-            print("NON-CRITICAL error in optimizer: ", ex_ext)
-            return None
+    def optimize_graph(graph, debug=False):
+        return optimizer.optimize_graph(graph, debug)

     @staticmethod
-    def optimize_graph_with_model_proto(onnx_model_proto, debug=False):
+    def optimize_model_proto(onnx_model_proto, debug=False):
         """Optimize the model proto, for example: eliminating all useless Transpose pairs.

         Returns:
@@ -1074,16 +1055,10 @@ def optimize_graph_with_model_proto(onnx_model_proto, debug=False):
         """
         try:
             kwargs = GraphUtil.get_onnx_model_properties(onnx_model_proto)
-            g = GraphUtil.create_graph_from_onnx_model(onnx_model_proto)
-
-            opts = [TransposeOptimizer(g, output_names=g.outputs, debug=debug),
-                    IdentityOptimizer(g, output_names=g.outputs, debug=debug)
-                    ]
-            for opt in opts:
-                opt.optimize()
-
-            model_proto = g.make_model(onnx_model_proto.graph.doc_string,
-                                       graph_name=onnx_model_proto.graph.name, **kwargs)
+            graph = GraphUtil.create_graph_from_onnx_model(onnx_model_proto)
+            graph = GraphUtil.optimize_graph(graph, debug)
+            model_proto = graph.make_model(onnx_model_proto.graph.doc_string,
+                                           graph_name=onnx_model_proto.graph.name, **kwargs)

             if onnx_model_proto.metadata_props:
                 metadata_props = {p.key: p.value for p in onnx_model_proto.metadata_props}
@@ -1123,11 +1098,12 @@ def create_graph_from_onnx_model(onnx_model_proto):
         # apply shape inference on the model
         inferred_model = shape_inference.infer_shapes(onnx_model_proto)
         graph_proto = inferred_model.graph
-        main_graph = GraphUtil.create_graph_from_onnx_graph(graph_proto)
+        opset_version = onnx_model_proto.opset_import[0].version
+        main_graph = GraphUtil.create_graph_from_onnx_graph(graph_proto, opset_version)
         return main_graph

     @staticmethod
-    def create_graph_from_onnx_graph(graph_proto):
+    def create_graph_from_onnx_graph(graph_proto, opset_version=None):
         """Create Graph loading onnx graph proto."""
         output_shapes = {}
         output_dtypes = {}
@@ -1154,15 +1130,15 @@ def create_graph_from_onnx_graph(graph_proto):
         for n in graph_proto.output:
             output_names.append(n.name)

-        g = Graph(nodes_to_append, output_shapes, output_dtypes, None, None, None, output_names)
+        g = Graph(nodes_to_append, output_shapes, output_dtypes, None, opset_version, None, output_names)
         const_nodes = GraphUtil._parse_graph_initializer(g, graph_proto)
         GraphUtil._parse_graph_input(g, graph_proto, [n.name for n in const_nodes])

         for n in g.get_nodes():
             for attr_name, attr_val in n.attr.items():
                 if attr_val.HasField('g'):
                     # it was assumed that the a.g has inferred shapes/dtypes.
-                    sub_g = GraphUtil.create_graph_from_onnx_graph(attr_val.g)
+                    sub_g = GraphUtil.create_graph_from_onnx_graph(attr_val.g, opset_version)
                     n.set_body_graph_as_attr(attr_name, sub_g)
         return g
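
The opset recorded in the incoming model is now read from opset_import and threaded through create_graph_from_onnx_graph, so the rebuilt Graph and every subgraph carry the model's opset instead of None. A hedged usage sketch, with a placeholder file name:

import onnx
from tf2onnx.graph import GraphUtil

model_proto = onnx.load("model.onnx")                      # placeholder path
opset = model_proto.opset_import[0].version                # what create_graph_from_onnx_model now reads
g = GraphUtil.create_graph_from_onnx_model(model_proto)    # the same opset is propagated into subgraphs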

tf2onnx/optimizer/__init__.py

Lines changed: 35 additions & 6 deletions

@@ -6,10 +6,39 @@
 from __future__ import print_function
 from __future__ import unicode_literals

-from .identity_optimizer import IdentityOptimizer
-from .transpose_optimizer import TransposeOptimizer
+import sys
+import traceback
+from collections import OrderedDict

-__all__ = [
-    "IdentityOptimizer",
-    "TransposeOptimizer",
-]
+from tf2onnx.optimizer.identity_optimizer import IdentityOptimizer
+from tf2onnx.optimizer.merge_duplicated_nodes_optimizer import MergeDuplicatedNodesOptimizer
+from tf2onnx.optimizer.transpose_optimizer import TransposeOptimizer
+
+# pylint: disable=missing-docstring, broad-except
+
+# the optimizer sequence needs to be chosen carefully
+_optimizers = OrderedDict([
+    ("transpose_opt", TransposeOptimizer),
+    # merge_duplicated_nodes should run after transpose_opt,
+    # because transpose_opt may leave transpose nodes that can be merged
+    ("merge_duplicated_nodes", MergeDuplicatedNodesOptimizer),
+    ("identity_opt", IdentityOptimizer),
+])
+
+
+def optimize_graph(graph, debug=False):
+    try:
+        opts = _get_optimizers()
+        for opt in opts.values():
+            graph = opt(debug=debug).optimize(graph)
+        return graph
+    except Exception:
+        # degrade to the non-optimized graph
+        type_, value_, traceback_ = sys.exc_info()
+        ex_ext = traceback.format_exception(type_, value_, traceback_)
+        print("NON-CRITICAL error in optimizer: ", ex_ext)
+        return None
+
+
+def _get_optimizers():
+    return _optimizers
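
optimize_graph simply walks the OrderedDict in order, constructing each optimizer with only a debug flag and passing the graph through optimize(). The ordering matters: transposes are rewritten first, the resulting duplicates are merged, and leftover Identity nodes are removed last. A hedged sketch of driving the same passes by hand with the new stateless interface:

from tf2onnx.optimizer import (IdentityOptimizer, MergeDuplicatedNodesOptimizer,
                               TransposeOptimizer)

def run_pipeline(graph, debug=False):
    # Same order as _optimizers; each optimizer returns the rewritten graph.
    for opt_cls in (TransposeOptimizer, MergeDuplicatedNodesOptimizer, IdentityOptimizer):
        graph = opt_cls(debug=debug).optimize(graph)
    return graph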

tf2onnx/optimizer/identity_optimizer.py

Lines changed: 16 additions & 13 deletions

@@ -8,30 +8,32 @@
 from __future__ import unicode_literals
 import logging

-logging.basicConfig(level=logging.INFO)
+from tf2onnx.optimizer.optimizer_base import GraphOptimizerBase
+
+
 log = logging.getLogger("tf2onnx.optimizer.identity_optimizer")


-# pylint: disable=logging-not-lazy,unused-argument,missing-docstring
-# FIXME:
-# pylint: disable=unused-variable
+# pylint: disable=logging-not-lazy,unused-argument,missing-docstring,unused-variable,arguments-differ


-class IdentityOptimizer(object):
+class IdentityOptimizer(GraphOptimizerBase):
     """Identity Optimizer."""

-    def __init__(self, main_graph, output_names, debug=False):
-        self._g = main_graph
-        self._output_names = [name.split(":")[0] for name in output_names]
-        self._debug = debug
+    def __init__(self, debug=False):
+        super(IdentityOptimizer, self).__init__("IdentityOptimizer", debug)

-    def optimize(self):
+        self._g = None
+
+    def optimize(self, graph):
+        self._g = graph
         previous_counter = self._g.dump_node_statistics()
         self._optimize_recursively(self._g)
         current_counter = self._g.dump_node_statistics()
         identity_cnt = current_counter["Identity"]
         current_counter.subtract(previous_counter)
         log.info(" %d identity op(s) left, ops diff after identity optimization: %s", identity_cnt, current_counter)
+        return self._g

     def _optimize_recursively(self, g):
         self._optimize(g)
@@ -64,13 +66,14 @@ def _optimize(self, g):

         self._g.topological_sort(self._g.get_nodes())

-
-    def _handle_non_graph_output_identity(self, graph, identity):
+    @staticmethod
+    def _handle_non_graph_output_identity(graph, identity):
         graph.replace_all_inputs(graph.get_nodes(), identity.output[0], identity.input[0])
         graph.remove_node(identity.name)
         return True

-    def _handle_graph_output_identity(self, graph, identity, graph_outputs):
+    @staticmethod
+    def _handle_graph_output_identity(graph, identity, graph_outputs):
         input_id = identity.input[0]
         input_node = identity.inputs[0]
