Skip to content

Commit b93c773

Browse files
committed
fix review comments
1 parent 70a24ad commit b93c773

File tree

4 files changed

+25
-22
lines changed

4 files changed

+25
-22
lines changed

tf2onnx/graph.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1063,18 +1063,25 @@ def optimize_model_proto(onnx_model_proto):
10631063
"""Optimize the model proto, for example: eliminating all useless Transpose pairs.
10641064
10651065
Returns:
1066-
model proto after optimization
1066+
model proto after optimization, if optimizer run successfully
1067+
or onnx_model_proto, if exceptions happens
10671068
"""
1068-
kwargs = GraphUtil.get_onnx_model_properties(onnx_model_proto)
1069-
graph = GraphUtil.create_graph_from_onnx_model(onnx_model_proto)
1070-
graph = GraphUtil.optimize_graph(graph)
1071-
model_proto = graph.make_model(onnx_model_proto.graph.doc_string,
1072-
graph_name=onnx_model_proto.graph.name, **kwargs)
1073-
1074-
if onnx_model_proto.metadata_props:
1075-
metadata_props = {p.key: p.value for p in onnx_model_proto.metadata_props}
1076-
helper.set_model_props(model_proto, metadata_props)
1077-
return model_proto
1069+
try:
1070+
kwargs = GraphUtil.get_onnx_model_properties(onnx_model_proto)
1071+
graph = GraphUtil.create_graph_from_onnx_model(onnx_model_proto)
1072+
graph = GraphUtil.optimize_graph(graph)
1073+
model_proto = graph.make_model(onnx_model_proto.graph.doc_string,
1074+
graph_name=onnx_model_proto.graph.name, **kwargs)
1075+
1076+
if onnx_model_proto.metadata_props:
1077+
metadata_props = {p.key: p.value for p in onnx_model_proto.metadata_props}
1078+
helper.set_model_props(model_proto, metadata_props)
1079+
return model_proto
1080+
except Exception:
1081+
# sometimes, onnx shape inference will fail for some reason,
1082+
# return onnx_model_proto for this case
1083+
logger.warning("Failed to optimize model proto", exc_info=1)
1084+
return onnx_model_proto
10781085

10791086
@staticmethod
10801087
def get_onnx_model_properties(onnx_model_proto):

tf2onnx/optimizer/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@
1717

1818
# optimizer sequence need to be considered carefully
1919
_optimizers = OrderedDict([
20-
("reduce_transpose", TransposeOptimizer),
20+
("optimize_transpose", TransposeOptimizer),
2121
("fold_constants", ConstFoldOptimizer),
22-
# merge_duplication should be used after reduce_transpose
23-
# for reduce_transpose may have some trans nodes that can be merge
22+
# merge_duplication should be used after optimize_transpose
23+
# for optimize_transpose may have some trans nodes that can be merge
2424
("merge_duplication", MergeDuplicatedNodesOptimizer),
25-
("reduce_identity", IdentityOptimizer),
25+
("remove_identity", IdentityOptimizer),
2626
])
2727

2828

tf2onnx/optimizer/optimizer_base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def is_debug_mode(self):
2626
return utils.is_debug_mode()
2727

2828
def optimize(self, graph):
29-
""" Optimize graph, return optimized graph """
29+
""" Optimize graph, return optimized graph. """
3030
before = graph.dump_node_statistics()
3131

3232
graph = self._optimize(graph)
@@ -38,6 +38,7 @@ def optimize(self, graph):
3838
return graph
3939

4040
def _optimize(self, graph):
41+
""" Derived class should override this function. """
4142
raise NotImplementedError
4243

4344
@staticmethod

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -129,10 +129,9 @@ def merge_duplicated_transposes(self):
129129
# dangling transpose nodes can be deleted
130130
graph.delete_unused_nodes(graph.outputs)
131131

132-
def optimize(self, graph):
132+
def _optimize(self, graph):
133133
self._g = graph
134134
self.pre_optimize_action()
135-
previous_counter = self._g.dump_node_statistics()
136135
no_action = False
137136
iteration_cnt = 0
138137
while not no_action:
@@ -161,10 +160,6 @@ def optimize(self, graph):
161160

162161
self.merge_duplicated_transposes()
163162
self.post_optimize_action()
164-
165-
current_counter = self._g.dump_node_statistics()
166-
transpose_cnt = current_counter["Transpose"]
167-
self._print_stat_diff(previous_counter, current_counter)
168163
return self._g
169164

170165
def _initialize_handlers(self):

0 commit comments

Comments
 (0)