
Commit 4fabf2d

refactor code according to review comments
1 parent 4326e02 commit 4fabf2d

6 files changed (+54 / -41 lines)

tf2onnx/optimizer/const_fold_optimizer.py

Lines changed: 25 additions & 21 deletions
```diff
@@ -6,8 +6,6 @@
 for example, input of transpose node is const then we can do transpose statically instead of at runtime
 """
 
-import logging
-
 from tf2onnx.optimizer.optimizer_base import GraphOptimizerBase
 from tf2onnx import utils
 
@@ -16,9 +14,9 @@
 _func_map = {}
 
 
-def _register_func(onnx_op):
+def _register_func(op_type):
     def _internal_fun(func):
-        _func_map[onnx_op] = func
+        _func_map[op_type] = func
         return func
     return _internal_fun
 
@@ -27,7 +25,6 @@ class ConstFoldOptimizer(GraphOptimizerBase):
 
     def __init__(self, debug=False):
         super(ConstFoldOptimizer, self).__init__("ConstFoldOptimizer", debug)
-        self._log = logging.getLogger("tf2onnx.optimizer.%s" % self._name)
 
     def _optimize(self, graph):
         return self._apply_optimization(graph, self._optimize_at_current_graph_level)
@@ -38,14 +35,18 @@ def _optimize_at_current_graph_level(self, graph):
         graph_changed = False
         ops = graph.get_nodes()
         for op in ops:
-            if self._can_skip(op):
+            if self._should_skip(op):
                 continue
             if self._fold_node(op, graph):
                 graph_changed = True
         return graph
 
     @staticmethod
-    def _can_skip(node):
+    def _should_skip(node):
+        # only support onnx official op for now, other op such as contrib op not supported for now
+        if not utils.is_onnx_domain(node.domain):
+            return True
+
         if node.is_const() or node.is_graph_input():
             return True
 
@@ -60,9 +61,13 @@ def _fold_node(self, node, graph):
        if node can be fold True will be return indicating that graph is changed
        """
        if self._all_inputs_are_const(node.inputs) and not self._is_graph_output(node, graph):
-            process_func = _func_map.get(node.type, self._try_fold)
-            return process_func(node, graph)
-
+            process_func = _func_map.get(node.type, None)
+            if process_func:
+                const_val_after_trans = process_func(node, graph)
+                self._replace_node_with_const(node, graph, const_val_after_trans)
+                return True
+            else:
+                self.log.warning("need to add function to fold op %s whose op_type is %s", node.name, node.type)
        return False
 
     @staticmethod
@@ -76,21 +81,20 @@ def _is_graph_output(node, graph):
         return node_out_set.intersection(graph_out_set)
 
     @staticmethod
-    def _replace_node_with_const(node, graph, val):
-        const_node = graph.make_const(utils.make_name("const_fold_opt"), val)
-        graph.replace_all_inputs(graph.get_nodes(), node.output[0], const_node.output[0])
+    def _replace_node_with_const(node, graph, vals):
+        utils.make_sure(len(node.output) == len(vals), "length of node outputs and const vals should be same")
+        for old_input, val in zip(node.output, vals):
+            const_node = graph.make_const(utils.make_name("const_fold_opt"), val)
+            graph.set_dtype(const_node.output[0], utils.map_numpy_to_onnx_dtype(val.dtype))
+            graph.set_shape(const_node.output[0], val.shape)
+            graph.replace_all_inputs(graph.get_nodes(), old_input, const_node.output[0])
         graph.remove_node(node.name)
 
     @staticmethod
-    @_register_func(onnx_op="Transpose")
-    def _fold_transpose(node, graph):
+    @_register_func("Transpose")
+    def _fold_transpose(node, graph) -> list:
         const_val = node.inputs[0].get_tensor_value(as_list=False)
         perm_attr = node.get_attr("perm")
         perm = perm_attr.ints if perm_attr else None
         const_val_after_trans = const_val.transpose(perm)
-        ConstFoldOptimizer._replace_node_with_const(node, graph, const_val_after_trans)
-        return True
-
-    def _try_fold(self, node, graph):
-        self._log.warning("need to add function to fold op %s whose op_type is %s", node.name, node.type)
-        return False
+        return [const_val_after_trans]
```
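Under the refactored contract, a fold handler registered in `_func_map` no longer mutates the graph itself: it returns a list of constant arrays, one per node output, and `_fold_node` performs the replacement. A sketch of how a handler for another op could plug into this scheme, assuming the module context above; the `Unsqueeze` handler below is hypothetical and not part of this commit:

```python
import numpy as np

# Hypothetical fold handler, for illustration only. The decorator stores it
# in _func_map keyed by op_type; _fold_node looks it up, calls it, and hands
# the returned list to _replace_node_with_const.
@_register_func("Unsqueeze")
def _fold_unsqueeze(node, graph) -> list:
    const_val = node.inputs[0].get_tensor_value(as_list=False)
    axes = node.get_attr("axes").ints
    val_after_unsqueeze = const_val
    for axis in sorted(axes):
        val_after_unsqueeze = np.expand_dims(val_after_unsqueeze, axis=axis)
    # return one constant per node output, matching the length check
    # in _replace_node_with_const
    return [val_after_unsqueeze]
```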

tf2onnx/optimizer/identity_optimizer.py

Lines changed: 7 additions & 9 deletions
```diff
@@ -6,7 +6,6 @@
 """
 
 from __future__ import unicode_literals
-import logging
 
 from tf2onnx.optimizer.optimizer_base import GraphOptimizerBase
 
@@ -18,7 +17,6 @@ class IdentityOptimizer(GraphOptimizerBase):
 
     def __init__(self, debug=False):
         super(IdentityOptimizer, self).__init__("IdentityOptimizer", debug)
-        self._log = logging.getLogger("tf2onnx.optimizer.%s" % self._name)
         self._g = None
 
     def optimize(self, graph):
@@ -27,7 +25,7 @@ def optimize(self, graph):
         self._optimize_recursively(self._g)
         current_counter = self._g.dump_node_statistics()
         identity_cnt = current_counter["Identity"]
-        self._log.info(" %d identity op(s) left", identity_cnt)
+        self.log.info(" %d identity op(s) left", identity_cnt)
         self._print_stat_diff(previous_counter, current_counter)
         return self._g
 
@@ -38,9 +36,9 @@ def _optimize_recursively(self, g):
             body_graphs = n.get_body_graphs()
             if body_graphs:
                 for attr, b_g in body_graphs.items():
-                    self._log.debug("start handling subgraph of %s's attribute %s", n.name, attr)
+                    self.log.debug("start handling subgraph of %s's attribute %s", n.name, attr)
                     self._optimize_recursively(b_g)
-                    self._log.debug("finish handling subgraph of %s's attribute %s", n.name, attr)
+                    self.log.debug("finish handling subgraph of %s's attribute %s", n.name, attr)
 
     def _optimize(self, g):
         has_update = True
@@ -49,7 +47,7 @@ def _optimize(self, g):
             nodes = [n for n in g.get_nodes() if n.type == "Identity"]
             for n in nodes:
                 if n.graph is None:
-                    self._log.info("node has been removed from this graph, skip")
+                    self.log.info("node has been removed from this graph, skip")
                     continue
 
                 graph_outputs = set(n.output).intersection(g.outputs)
@@ -74,12 +72,12 @@ def _handle_graph_output_identity(self, graph, identity, graph_outputs):
 
         if input_node.graph != graph:
             # If input node is in parent graph, we don't handle it now
-            self._log.debug("input node in parent graph, skip")
+            self.log.debug("input node in parent graph, skip")
             return False
 
         if input_node.is_graph_input():
             # Identity between input and output should not be removed.
-            self._log.debug("skip identity between input and output")
+            self.log.debug("skip identity between input and output")
             return False
 
         output_id = identity.output[0]
@@ -88,7 +86,7 @@ def _handle_graph_output_identity(self, graph, identity, graph_outputs):
         if input_id in graph.outputs:
             # input id already be graph output, so we cannot make that be another graph output.
             # this Identity must be kept.
-            self._log.debug("identity input already be graph output")
+            self.log.debug("identity input already be graph output")
             return False
 
         graph.remove_node(identity.name)
```

tf2onnx/optimizer/merge_duplicated_nodes_optimizer.py

Lines changed: 0 additions & 2 deletions
```diff
@@ -8,7 +8,6 @@
 """
 
 from collections import defaultdict, namedtuple
-import logging
 
 from tf2onnx.optimizer.optimizer_base import GraphOptimizerBase
 
@@ -23,7 +22,6 @@ class MergeDuplicatedNodesOptimizer(GraphOptimizerBase):
     """
     def __init__(self, debug=False):
         super(MergeDuplicatedNodesOptimizer, self).__init__("MergeDuplicatedNodesOptimizer", debug)
-        self._log = logging.getLogger("tf2onnx.optimizer.%s" % self._name)
         # used internally
         self._graph_can_be_optimized = True
 
```

tf2onnx/optimizer/optimizer_base.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -4,6 +4,7 @@
 """Graph Optimizer Base"""
 
 from __future__ import unicode_literals
+import logging
 
 
 class GraphOptimizerBase(object):
@@ -13,6 +14,7 @@ class GraphOptimizerBase(object):
     def __init__(self, name, debug=False):
         self._debug = debug
         self._name = name
+        self._log = logging.getLogger("tf2onnx.optimizer.%s" % self._name)
 
     def optimize(self, graph):
         original_node_statistics = graph.dump_node_statistics()
@@ -30,6 +32,10 @@ def _optimize(self, graph):
     def name(self):
         return self._name
 
+    @property
+    def log(self):
+        return self._log
+
     @staticmethod
     def _apply_optimization(graph, optimize_func):
         """
@@ -54,4 +60,4 @@ def _print_stat_diff(self, nodes_original, nodes_after_optimized):
         for key, value in nodes_after_optimized.items():
             if value != 0:
                 res[key] = value
-        self._log.info("the optimization gain is %s", res)
+        self.log.info("the optimization gain is %s", res)
```
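Centralizing the logger in `GraphOptimizerBase` means every subclass gets a correctly named `tf2onnx.optimizer.<name>` logger for free through the read-only `log` property. A minimal sketch of the resulting usage pattern, with a made-up `DummyOptimizer` subclass for illustration:

```python
class DummyOptimizer(GraphOptimizerBase):
    # hypothetical subclass; the real ones are ConstFoldOptimizer,
    # IdentityOptimizer, MergeDuplicatedNodesOptimizer, TransposeOptimizer
    def __init__(self, debug=False):
        # base __init__ creates logging.getLogger("tf2onnx.optimizer.DummyOptimizer"),
        # so no per-subclass logger setup is needed anymore
        super(DummyOptimizer, self).__init__("DummyOptimizer", debug)

    def _optimize(self, graph):
        self.log.debug("nothing to do for %s", self.name)  # inherited property
        return graph
```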

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 7 additions & 8 deletions
```diff
@@ -6,7 +6,6 @@
 from __future__ import unicode_literals
 from collections import defaultdict
 
-import logging
 
 import numpy as np
 
@@ -38,7 +37,7 @@ class TransposeOptimizer(GraphOptimizerBase):
 
     def __init__(self, debug=False):
         super(TransposeOptimizer, self).__init__("TransposeOptimizer", debug)
-        self._log = logging.getLogger("tf2onnx.optimizer.%s" % self._name)
+
         self._handler_map = {}
         self._force_stop = {}
 
@@ -156,17 +155,17 @@ def optimize(self, graph):
             if "stop" in self._force_stop and self._force_stop["stop"] == 1:
                 break
 
-        self._log.debug("finish after " + str(iteration_cnt) + " iteration(s)")
+        self.log.debug("finish after " + str(iteration_cnt) + " iteration(s)")
 
         self.merge_duplicated_transposes()
         self.post_optimize_action()
 
         current_counter = self._g.dump_node_statistics()
         transpose_cnt = current_counter["Transpose"]
-        self._log.info(" %d transpose op(s) left", transpose_cnt)
+        self.log.info(" %d transpose op(s) left", transpose_cnt)
         self._print_stat_diff(previous_counter, current_counter)
         if transpose_cnt > 2:
-            self._log.warning("please try add --fold_const to help remove more transpose")
+            self.log.warning("please try add --fold_const to help remove more transpose")
         return self._g
 
     def _initialize_handlers(self):
@@ -217,7 +216,7 @@ def _handle_node_having_branches(self, node):
             self._g.remove_node(n.name)
             return True
 
-        self._log.debug("input transpose does not have single consumer, skipping...")
+        self.log.debug("input transpose does not have single consumer, skipping...")
         return False
 
     # get the input index of transpose op in node's inputs.
@@ -255,13 +254,13 @@ def _switch_transpose_and_node(self, node, trans):
     # otherwise, it means that we skip handling since it is not in our support set
     def _handle_nhwc_tranpose(self, trans):
         if trans.output[0] in self._g.outputs:
-            log.debug("%s connects to graph outputs, skip", trans.output[0])
+            self.log.debug("%s connects to graph outputs, skip", trans.output[0])
             return False
         out_nodes = self._g.find_output_consumers(trans.output[0])
         if len(out_nodes) == 1:
             p = out_nodes[0]
             if p.name in self._output_names:
-                self._log.debug("cannot move transpose down since it met output node %s", p.name)
+                self.log.debug("cannot move transpose down since it met output node %s", p.name)
                 return False
 
             if p.type in self._handler_map:
```

tf2onnx/utils.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -428,6 +428,14 @@ def get_onnx_version():
     return onnx.__version__
 
 
+
 def make_opsetid(domain, version):
     make_sure(isinstance(version, int), "version must be an integer")
     return helper.make_opsetid(domain, version)
+
+
+def is_onnx_domain(domain):
+    if domain is None or domain == "":
+        return True
+    return False
+
```
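`is_onnx_domain` treats both `None` and the empty string as the default ONNX domain, which is what `_should_skip` in the const-fold optimizer relies on to skip non-official ops. A quick sketch of the expected behavior; the contrib domain string is only an example:

```python
from tf2onnx import utils

assert utils.is_onnx_domain(None)       # unset domain defaults to ONNX
assert utils.is_onnx_domain("")         # empty string is the ONNX domain too
assert not utils.is_onnx_domain("ai.onnx.contrib")  # any named domain is not
```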
