Commit c5b7992

refactor optimizer
1 parent 84edabc commit c5b7992

File tree: 8 files changed, +84 -106 lines

tf2onnx/convert.py

Lines changed: 4 additions & 5 deletions

@@ -85,6 +85,8 @@ def main():
     if args.debug:
         utils.set_debug_mode(True)
 
+    logger = logging.getLogger(constants.TF2ONNX_PACKAGE_NAME)
+
     # override unknown dimensions from -1 to 1 (aka batchsize 1) since not every runtime does
     # support unknown dimensions.
     utils.ONNX_UNKNOWN_DIMENSION = args.unknown_dim
@@ -126,11 +128,8 @@ def main():
 
     model_proto = g.make_model("converted from {}".format(model_path))
 
-    new_model_proto = GraphUtil.optimize_model_proto(model_proto)
-    if new_model_proto:
-        model_proto = new_model_proto
-    else:
-        print("NON-CRITICAL, optimizers are not applied successfully")
+    logger.info("\n")
+    model_proto = GraphUtil.optimize_model_proto(model_proto)
 
     # write onnx graph
     if args.output:
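
The caller side in convert.py is simpler now because optimize_model_proto() no longer returns None on failure. A minimal sketch of the new calling contract (the helper name and file paths below are illustrative, not part of the commit):

```python
# Sketch only: optimize_model_proto() now always returns a model proto, so the
# caller assigns its result directly instead of checking for None.
from tf2onnx.graph import GraphUtil

def convert_and_optimize(g, model_path, output_path):
    # 'g' is a tf2onnx Graph produced by the converter; make_model exists in the source.
    model_proto = g.make_model("converted from {}".format(model_path))
    # Optimizer failures are caught and logged inside optimize_graph(), so this
    # call is not expected to raise just because one optimizer fails.
    model_proto = GraphUtil.optimize_model_proto(model_proto)
    with open(output_path, "wb") as f:
        f.write(model_proto.SerializeToString())
    return model_proto
```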

tf2onnx/graph.py

Lines changed: 11 additions & 22 deletions

@@ -12,8 +12,6 @@
 import collections
 import copy
 import logging
-import sys
-import traceback
 import six
 import numpy as np
 
@@ -1065,27 +1063,18 @@ def optimize_model_proto(onnx_model_proto):
         """Optimize the model proto, for example: eliminating all useless Transpose pairs.
 
         Returns:
-            model proto after optimization, if optimizer run successfully
-            or None, if exceptions happens
+            model proto after optimization
         """
-        try:
-            kwargs = GraphUtil.get_onnx_model_properties(onnx_model_proto)
-            graph = GraphUtil.create_graph_from_onnx_model(onnx_model_proto)
-            graph = GraphUtil.optimize_graph(graph)
-            model_proto = graph.make_model(onnx_model_proto.graph.doc_string,
-                                           graph_name=onnx_model_proto.graph.name, **kwargs)
-
-            if onnx_model_proto.metadata_props:
-                metadata_props = {p.key: p.value for p in onnx_model_proto.metadata_props}
-                helper.set_model_props(model_proto, metadata_props)
-            return model_proto
-        except Exception:
-            # sometimes, onnx shape inference will fail for some reason, in this case,
-            # we just log the error, and skip the transpose optimizer.
-            type_, value_, traceback_ = sys.exc_info()
-            ex_ext = traceback.format_exception(type_, value_, traceback_)
-            print("NON-CRITICAL error in optimizer: ", ex_ext)
-            return None
+        kwargs = GraphUtil.get_onnx_model_properties(onnx_model_proto)
+        graph = GraphUtil.create_graph_from_onnx_model(onnx_model_proto)
+        graph = GraphUtil.optimize_graph(graph)
+        model_proto = graph.make_model(onnx_model_proto.graph.doc_string,
+                                       graph_name=onnx_model_proto.graph.name, **kwargs)
+
+        if onnx_model_proto.metadata_props:
+            metadata_props = {p.key: p.value for p in onnx_model_proto.metadata_props}
+            helper.set_model_props(model_proto, metadata_props)
+        return model_proto
 
     @staticmethod
     def get_onnx_model_properties(onnx_model_proto):
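
The rewritten optimize_model_proto still carries the input model's metadata_props over to the rebuilt proto via onnx.helper.set_model_props. A small standalone sketch of that preservation step using the standard onnx API (the input file name is hypothetical):

```python
# Illustrative sketch of the metadata-preservation step in optimize_model_proto.
import onnx
from onnx import helper

original = onnx.load("model.onnx")   # hypothetical input file
rebuilt = onnx.ModelProto()
rebuilt.CopyFrom(original)           # stand-in for the proto rebuilt by make_model
if original.metadata_props:
    props = {p.key: p.value for p in original.metadata_props}
    helper.set_model_props(rebuilt, props)  # copies key/value metadata onto the new proto
```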

tf2onnx/optimizer/__init__.py

Lines changed: 38 additions & 30 deletions

@@ -6,43 +6,51 @@
 from __future__ import print_function
 from __future__ import unicode_literals
 
-import sys
-import traceback
 from collections import OrderedDict
+import copy
 
-from tf2onnx.optimizer.const_fold_optimizer import ConstFoldOptimizer
-from tf2onnx.optimizer.identity_optimizer import IdentityOptimizer
-from tf2onnx.optimizer.merge_duplicated_nodes_optimizer import MergeDuplicatedNodesOptimizer
-from tf2onnx.optimizer.transpose_optimizer import TransposeOptimizer
-
-# pylint: disable=missing-docstring, broad-except
+from .const_fold_optimizer import ConstFoldOptimizer
+from .identity_optimizer import IdentityOptimizer
+from .merge_duplicated_nodes_optimizer import MergeDuplicatedNodesOptimizer
+from .transpose_optimizer import TransposeOptimizer
+from .. import logging
 
 # optimizer sequence need to be considered carefully
 _optimizers = OrderedDict([
-    ("transpose_opt", TransposeOptimizer),
-    ("fold_const", ConstFoldOptimizer),
-    # merge_duplicated_nodes should be used after transpose_opt
-    # for transpose_opt may have some trans nodes that can be merge
-    ("merge_duplicated_nodes", MergeDuplicatedNodesOptimizer),
-    ("identity_opt", IdentityOptimizer),
+    ("reduce_transpose", TransposeOptimizer),
+    ("fold_constants", ConstFoldOptimizer),
+    # merge_duplication should be used after reduce_transpose
+    # for reduce_transpose may have some trans nodes that can be merge
+    ("merge_duplication", MergeDuplicatedNodesOptimizer),
+    ("reduce_identity", IdentityOptimizer),
 ])
 
 
-def optimize_graph(graph):
-    try:
-        opts = _get_optimizers()
-        for opt in opts.values():
-            graph = opt().optimize(graph)
-
-        graph.update_proto()
-        return graph
-    except Exception:
-        # degradation to non-optimized model proto
-        type_, value_, traceback_ = sys.exc_info()
-        ex_ext = traceback.format_exception(type_, value_, traceback_)
-        print("NON-CRITICAL error in optimizer: ", ex_ext)
-        return None
-
-
 def _get_optimizers():
     return _optimizers
+
+
+def optimize_graph(graph):
+    """ Optimize graph, return optimized graph. No throw. """
+    logger = logging.getLogger(__name__)
+    logger.info("Optimizing ONNX model")
+
+    before = graph.dump_node_statistics()
+    opts = _get_optimizers()
+    for name, factory in opts.items():
+        try:
+            logger.verbose("Apply %s", name)
+            current = copy.deepcopy(graph)
+            graph = factory().optimize(current)
+        except Exception:  # pylint: disable=broad-except
+            # if current optimizer fails, continue with other optimizers
+            logger.warning("Failed to apply %s", name, exc_info=1)
+
+    after = graph.dump_node_statistics()
+    diff = copy.deepcopy(after)
+    diff.subtract(before)
+    diff = ["{} {} ({}->{})".format(k, str(v) if v < 0 else '+' + str(v), before.get(k, 0), after.get(k, 0))
+            for k, v in diff.most_common() if v != 0]
+    logger.info("After optimization: %s", ', '.join(diff) if diff else "no change")
+
+    return graph
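
The new optimize_graph logs how node counts change instead of printing tracebacks. A self-contained sketch of that Counter-based diff, assuming dump_node_statistics() returns a collections.Counter keyed by op type (which is what the subtract()/most_common() calls rely on); the counts below are made up:

```python
# Standalone illustration of the statistics diff logged by optimize_graph.
import copy
from collections import Counter

before = Counter({"Transpose": 12, "Identity": 7, "Add": 3})   # made-up counts
after = Counter({"Transpose": 2, "Identity": 1, "Add": 3, "Const": 1})

diff = copy.deepcopy(after)
diff.subtract(before)
lines = ["{} {} ({}->{})".format(k, str(v) if v < 0 else '+' + str(v),
                                 before.get(k, 0), after.get(k, 0))
         for k, v in diff.most_common() if v != 0]
print(', '.join(lines) if lines else "no change")
# -> "Const +1 (0->1), Identity -6 (7->1), Transpose -10 (12->2)"
```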

tf2onnx/optimizer/const_fold_optimizer.py

Lines changed: 2 additions & 2 deletions

@@ -6,8 +6,8 @@
 for example, input of transpose node is const then we can do transpose statically instead of at runtime
 """
 
-from tf2onnx.optimizer.optimizer_base import GraphOptimizerBase
-from tf2onnx import utils
+from .. import utils
+from .optimizer_base import GraphOptimizerBase
 
 # pylint: disable=logging-not-lazy,unused-argument,missing-docstring

tf2onnx/optimizer/identity_optimizer.py

Lines changed: 9 additions & 29 deletions

@@ -7,7 +7,7 @@
 
 from __future__ import unicode_literals
 
-from tf2onnx.optimizer.optimizer_base import GraphOptimizerBase
+from .optimizer_base import GraphOptimizerBase
 
 
 # pylint: disable=logging-not-lazy,unused-argument,missing-docstring,unused-variable,arguments-differ
@@ -16,39 +16,20 @@
 class IdentityOptimizer(GraphOptimizerBase):
     """Identity Optimizer."""
 
-    def __init__(self):
+    def __init__(self):  # pylint: disable=useless-super-delegation
         super(IdentityOptimizer, self).__init__()
-        self._g = None
-
-    def optimize(self, graph):
-        self._g = graph
-        previous_counter = self._g.dump_node_statistics()
-        self._optimize_recursively(self._g)
-        current_counter = self._g.dump_node_statistics()
-        identity_cnt = current_counter["Identity"]
-        self.logger.info(" %d identity op(s) left", identity_cnt)
-        self._print_stat_diff(previous_counter, current_counter)
-        return self._g
-
-    def _optimize_recursively(self, g):
-        self._optimize(g)
-        nodes = [n for n in g.get_nodes()]
-        for n in nodes:
-            body_graphs = n.get_body_graphs()
-            if body_graphs:
-                for attr, b_g in body_graphs.items():
-                    self.logger.debug("start handling subgraph of %s's attribute %s", n.name, attr)
-                    self._optimize_recursively(b_g)
-                    self.logger.debug("finish handling subgraph of %s's attribute %s", n.name, attr)
-
-    def _optimize(self, g):
+
+    def _optimize(self, graph):
+        return self._apply_optimization(graph, self._optimize_at_current_graph_level)
+
+    def _optimize_at_current_graph_level(self, g):
         has_update = True
         while has_update:
             has_update = False
             nodes = [n for n in g.get_nodes() if n.type == "Identity"]
             for n in nodes:
                 if n.graph is None:
-                    self.logger.info("node has been removed from this graph, skip")
+                    self.logger.debug("node has been removed from this graph, skip")
                     continue
 
                 graph_outputs = set(n.output).intersection(g.outputs)
@@ -58,8 +39,7 @@ def _optimize(self, g):
                 else:
                     ret = self._handle_non_graph_output_identity(g, n)
                 has_update = ret
-
-        self._g.topological_sort(self._g.get_nodes())
+        return g
 
     @staticmethod
     def _handle_non_graph_output_identity(graph, identity):
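
IdentityOptimizer now follows the pattern GraphOptimizerBase provides: subclasses implement _optimize() and delegate subgraph traversal to _apply_optimization(). A hedged sketch of that shape (the NoopOptimizer class and its body are illustrative only, not part of this commit):

```python
# Illustrative subclass showing the structure the refactor converges on.
from tf2onnx.optimizer.optimizer_base import GraphOptimizerBase

class NoopOptimizer(GraphOptimizerBase):
    """Does nothing; demonstrates the per-graph-level optimizer pattern."""

    def _optimize(self, graph):
        # _apply_optimization recurses into body graphs (e.g. Loop/If subgraphs)
        # and applies the given function at every graph level.
        return self._apply_optimization(graph, self._optimize_at_current_graph_level)

    def _optimize_at_current_graph_level(self, g):
        # Inspect or rewrite nodes of this graph level only, then return it.
        for node in g.get_nodes():
            self.logger.debug("visiting %s node %s", node.type, node.name)
        return g
```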

tf2onnx/optimizer/merge_duplicated_nodes_optimizer.py

Lines changed: 1 addition & 1 deletion

@@ -9,7 +9,7 @@
 
 from collections import defaultdict, namedtuple
 
-from tf2onnx.optimizer.optimizer_base import GraphOptimizerBase
+from .optimizer_base import GraphOptimizerBase
 
 # pylint: disable=logging-not-lazy,unused-argument,missing-docstring

tf2onnx/optimizer/optimizer_base.py

Lines changed: 17 additions & 12 deletions

@@ -5,7 +5,9 @@
 
 from __future__ import unicode_literals
 
-from tf2onnx import logging, utils
+import copy
+
+from .. import logging, utils
 
 
 class GraphOptimizerBase(object):
@@ -24,11 +26,15 @@ def is_debug_mode(self):
         return utils.is_debug_mode()
 
     def optimize(self, graph):
-        original_node_statistics = graph.dump_node_statistics()
+        """ Optimize graph, return optimized graph """
+        before = graph.dump_node_statistics()
+
         graph = self._optimize(graph)
+        graph.update_proto()
         graph.delete_unused_nodes(graph.outputs)
-        node_statistics = graph.dump_node_statistics()
-        self._print_stat_diff(original_node_statistics, node_statistics)
+
+        after = graph.dump_node_statistics()
+        self._print_stat_diff(before, after)
         return graph
 
     def _optimize(self, graph):
@@ -48,14 +54,13 @@ def _apply_optimization(graph, optimize_func):
             body_graphs = node.get_body_graphs()
             if body_graphs:
                 for attr, b_g in body_graphs.items():
-                    b_g = optimize_func(b_g)
+                    b_g = GraphOptimizerBase._apply_optimization(b_g, optimize_func)
                     node.set_body_graph_as_attr(attr, b_g)
         return graph
 
-    def _print_stat_diff(self, nodes_original, nodes_after_optimized):
-        nodes_after_optimized.subtract(nodes_original)
-        res = {}
-        for key, value in nodes_after_optimized.items():
-            if value != 0:
-                res[key] = value
-        self.logger.info("the optimization gain is %s", res)
+    def _print_stat_diff(self, before, after):
+        diff = copy.deepcopy(after)
+        diff.subtract(before)
+        diff = ["{} {} ({}->{})".format(k, str(v) if v < 0 else '+' + str(v), before.get(k, 0), after.get(k, 0))
+                for k, v in diff.most_common() if v != 0]
+        self.logger.verbose(', '.join(diff) if diff else "no change")

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 2 additions & 5 deletions

@@ -8,8 +8,8 @@
 
 import numpy as np
 
-from tf2onnx import utils
-from tf2onnx.optimizer.optimizer_base import GraphOptimizerBase
+from .. import utils
+from .optimizer_base import GraphOptimizerBase
 
 
 # pylint: disable=logging-not-lazy,unused-argument,missing-docstring,abstract-method
@@ -164,10 +164,7 @@ def optimize(self, graph):
 
         current_counter = self._g.dump_node_statistics()
         transpose_cnt = current_counter["Transpose"]
-        self.logger.info(" %d transpose op(s) left", transpose_cnt)
         self._print_stat_diff(previous_counter, current_counter)
-        if transpose_cnt > 2:
-            self.logger.warning("please try add --fold_const to help remove more transpose")
         return self._g
 
     def _initialize_handlers(self):
