Skip to content

Commit 1e342bd

Browse files
committed
enhance: if graph been optimized, then optimzers should be run again.
1 parent 4c5423f commit 1e342bd

File tree

7 files changed

+23
-8
lines changed

7 files changed

+23
-8
lines changed

tf2onnx/optimizer/__init__.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,21 @@ def optimize_graph(graph):
3939

4040
before = graph.dump_node_statistics()
4141
opts = _get_optimizers()
42-
for name, factory in opts.items():
43-
try:
44-
logger.verbose("Apply %s", name)
45-
current = copy.deepcopy(graph)
46-
graph = factory().optimize(current)
47-
except Exception: # pylint: disable=broad-except
48-
# if current optimizer fails, continue with other optimizers
49-
logger.warning("Failed to apply %s", name, exc_info=1)
42+
continue_flag = True
43+
while continue_flag:
44+
continue_flag = False
45+
for name, factory in opts.items():
46+
try:
47+
logger.verbose("Apply %s", name)
48+
current = copy.deepcopy(graph)
49+
opt = factory()
50+
graph = opt.optimize(current)
51+
if not continue_flag:
52+
continue_flag = opt.graph_been_opt
53+
54+
except Exception: # pylint: disable=broad-except
55+
# if current optimizer fails, continue with other optimizers
56+
logger.warning("Failed to apply %s", name, exc_info=1)
5057

5158
after = graph.dump_node_statistics()
5259
diff = copy.deepcopy(after)

tf2onnx/optimizer/const_fold_optimizer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def _optimize_at_current_graph_level(self, graph):
4242
continue
4343
if self._fold_node(op, graph):
4444
graph_changed = True
45+
self.graph_been_opt = True
4546
return graph
4647

4748
@staticmethod

tf2onnx/optimizer/identity_optimizer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ def _optimize_at_current_graph_level(self, g):
3939
else:
4040
ret = self._handle_non_graph_output_identity(g, n)
4141
has_update = ret
42+
if ret:
43+
self.graph_been_opt = True
4244
return g
4345

4446
@staticmethod

tf2onnx/optimizer/loop_optimizer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def _optimize_at_current_graph_level(self, g):
3232
has_update_tmp = self._try_move_transpose_out_of_body_graph(n)
3333
if has_update_tmp:
3434
has_update = True
35+
self.graph_been_opt = True
3536
return g
3637

3738
@staticmethod

tf2onnx/optimizer/merge_duplicated_nodes_optimizer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ def _optimize_at_current_graph_level(self, graph):
3232
while self._graph_can_be_optimized:
3333
self._graph_can_be_optimized = False
3434
self._merge_duplicated_nodes(graph)
35+
if self._graph_can_be_optimized:
36+
self.graph_been_opt = True
3537
return graph
3638

3739
def _merge_duplicated_nodes(self, graph):

tf2onnx/optimizer/optimizer_base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ class GraphOptimizerBase(object):
1616

1717
def __init__(self):
1818
self._logger = logging.getLogger('.'.join(__name__.split('.')[:-1] + [self.__class__.__name__]))
19+
self.graph_been_opt = False
1920

2021
@property
2122
def logger(self):

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ def _optimize_at_current_graph_level(self, graph):
147147
if is_nhwc_transpose(n):
148148
if self._handle_nhwc_tranpose(n):
149149
no_action = False
150+
self.graph_been_opt = True
150151
iteration_cnt += 1
151152
# need break, because handler may change nodes set, making the n stale object
152153
# referencing already deleted elements

0 commit comments

Comments
 (0)