Skip to content

Commit 469435b

Browse files
committed
refactor
1 parent f13d8b6 commit 469435b

File tree

4 files changed

+10
-8
lines changed

4 files changed

+10
-8
lines changed

tf2onnx/optimizer/identity_optimizer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@ class IdentityOptimizer(GraphOptimizerBase):
2121
"""Identity Optimizer."""
2222

2323
def __init__(self, debug=False):
24-
super(IdentityOptimizer, self).__init__(debug)
25-
self._name = "IdentityOptimizer"
24+
super(IdentityOptimizer, self).__init__("IdentityOptimizer", debug)
2625

2726
self._g = None
2827

tf2onnx/optimizer/merge_duplicated_nodes_optimizer.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,12 @@
1717

1818
_KeyToGroupNodes = namedtuple("key", "type input")
1919

20+
2021
class MergeDuplicatedNodesOptimizer(GraphOptimizerBase):
2122
"""Remove duplicate nodes.
2223
"""
2324
def __init__(self, debug=False):
24-
super(MergeDuplicatedNodesOptimizer, self).__init__(debug)
25-
# optimizer should have name and log property
26-
self._name = "MergeDuplicatedNodesOptimizer"
25+
super(MergeDuplicatedNodesOptimizer, self).__init__("MergeDuplicatedNodesOptimizer", debug)
2726
self._log = logging.getLogger("tf2onnx.optimizer.%s" % self._name)
2827
# used internally
2928
self._graph_can_be_optimized = True

tf2onnx/optimizer/optimizer_base.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@ class GraphOptimizerBase(object):
1010
"""optimizer graph to improve performance
1111
"""
1212

13-
def __init__(self, debug=False):
13+
def __init__(self, name, debug=False):
1414
self._debug = debug
15+
self._name = name
1516

1617
def optimize(self, graph):
1718
original_node_statistics = graph.dump_node_statistics()
@@ -23,6 +24,10 @@ def optimize(self, graph):
2324
def _optimize(self, graph):
2425
raise NotImplementedError
2526

27+
@property
28+
def name(self):
29+
return self._name
30+
2631
@staticmethod
2732
def _apply_optimization(graph, optimize_func):
2833
"""

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,7 @@ class TransposeOptimizer(GraphOptimizerBase):
4040
"""Transpose Optimizer."""
4141

4242
def __init__(self, debug=False):
43-
super(TransposeOptimizer, self).__init__(debug)
44-
self._name = "TransposeOptimizer"
43+
super(TransposeOptimizer, self).__init__("TransposeOptimizer", debug)
4544
self._handler_map = {}
4645
self._force_stop = {}
4746

0 commit comments

Comments
 (0)