Skip to content

Commit 81c872b

Browse files
committed
refactor code
1 parent b75d5bf commit 81c872b

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

tf2onnx/optimizer/const_fold_optimizer.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111

1212
# pylint: disable=logging-not-lazy,unused-argument,missing-docstring
1313

14+
# key is op_type, value is the function to compute outputs
15+
# the schema of function is: inputs are(node, graph), output is a list of constant values.
1416
_func_map = {}
1517

1618

@@ -43,7 +45,7 @@ def _optimize_at_current_graph_level(self, graph):
4345

4446
@staticmethod
4547
def _should_skip(node):
46-
# only support onnx official op for now, other op such as contrib op not supported for now
48+
# only support onnx official op for now, op in other domain is not supported for now
4749
if not utils.is_onnx_domain(node.domain):
4850
return True
4951

@@ -63,10 +65,10 @@ def _fold_node(self, node, graph):
6365
if self._all_inputs_are_const(node.inputs) and not self._is_graph_output(node, graph):
6466
process_func = _func_map.get(node.type, None)
6567
if process_func:
66-
const_val_after_trans = process_func(node, graph)
67-
self._replace_node_with_const(node, graph, const_val_after_trans)
68+
const_outputs = process_func(node, graph)
69+
self._replace_node_with_const(node, graph, const_outputs)
6870
return True
69-
self.log.warning("need to add function to fold op %s whose op_type is %s", node.name, node.type)
71+
self.log.debug("need to add function to fold op %s whose op_type is %s", node.name, node.type)
7072
return False
7173

7274
@staticmethod

tf2onnx/optimizer/optimizer_base.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ def optimize(self, graph):
2020
original_node_statistics = graph.dump_node_statistics()
2121
graph = self._optimize(graph)
2222
graph.delete_unused_nodes(graph.outputs)
23-
graph.topological_sort(graph.get_nodes())
2423
node_statistics = graph.dump_node_statistics()
2524
self._print_stat_diff(original_node_statistics, node_statistics)
2625
return graph

0 commit comments

Comments
 (0)