Skip to content

Commit c348562

Browse files
committed
refactor
1 parent f850cbe commit c348562

File tree

6 files changed

+26
-23
lines changed

6 files changed

+26
-23
lines changed

tests/backend_test_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from common import get_test_config
1919
from tf2onnx import utils
2020
from tf2onnx.tfonnx import process_tf_graph, tf_optimize
21-
from tf2onnx.graph import GraphUtil
21+
from tf2onnx import optimizer
2222

2323

2424
# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test
@@ -141,7 +141,7 @@ def run_test_case(self, feed_dict, input_names_with_port, output_names_with_port
141141
with tf.Session() as sess:
142142
g = process_tf_graph(sess.graph, opset=self.config.opset, output_names=output_names_with_port,
143143
target=self.config.target, **process_args)
144-
g = GraphUtil.optimize_graph(g)
144+
g = optimizer.optimize_graph(g)
145145
actual = self._run_backend(g, output_names_with_port, onnx_feed_dict)
146146

147147
for expected_val, actual_val in zip(expected, actual):

tests/run_pretrained_models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
import tf2onnx
3131
from tf2onnx import loader
3232
from tf2onnx import utils
33-
from tf2onnx.graph import GraphUtil
33+
from tf2onnx import optimizer
3434
from tf2onnx.tfonnx import process_tf_graph
3535

3636
# pylint: disable=broad-except,logging-not-lazy,unused-argument,unnecessary-lambda
@@ -270,7 +270,7 @@ def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=No
270270
onnx_graph = self.to_onnx(sess.graph, opset=opset, shape_override=shape_override,
271271
input_names=inputs.keys())
272272
model_proto = onnx_graph.make_model("converted from tf2onnx")
273-
new_model_proto = GraphUtil.optimize_graph(onnx_graph, debug=debug).make_model("optimized")
273+
new_model_proto = optimizer.optimize_graph(onnx_graph, debug=debug).make_model("optimized")
274274
if new_model_proto:
275275
model_proto = new_model_proto
276276
else:

tf2onnx/optimizer/identity_optimizer.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,22 @@
88
from __future__ import unicode_literals
99
import logging
1010

11-
logging.basicConfig(level=logging.INFO)
11+
from tf2onnx.optimizer.optimizer_base import GraphOptimizerBase
12+
13+
1214
log = logging.getLogger("tf2onnx.optimizer.identity_optimizer")
1315

1416

1517
# pylint: disable=logging-not-lazy,unused-argument,missing-docstring,unused-variable
1618

1719

18-
class IdentityOptimizer(object):
20+
class IdentityOptimizer(GraphOptimizerBase):
1921
"""Identity Optimizer."""
2022

2123
def __init__(self, debug=False):
22-
self._debug = debug
24+
super(IdentityOptimizer, self).__init__(debug)
25+
self._name = "IdentityOptimizer"
26+
2327
self._g = None
2428

2529
def optimize(self, graph):

tf2onnx/optimizer/merge_duplicated_nodes_optimizer.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
"""
99

1010
from collections import defaultdict, namedtuple
11+
import logging
1112

1213
from tf2onnx.optimizer.optimizer_base import GraphOptimizerBase
1314

@@ -18,11 +19,14 @@
1819
class MergeDuplicatedNodesOptimizer(GraphOptimizerBase):
1920
"""Remove duplicate nodes.
2021
"""
21-
_key_to_group_nodes = namedtuple("key", "type input")
22-
23-
def __init__(self, name="MergeDuplicatedNodesOptimizer", debug=False):
24-
super(MergeDuplicatedNodesOptimizer, self).__init__(name=name)
25-
22+
_KeyToGroupNodes = namedtuple("key", "type input")
23+
24+
def __init__(self, debug=False):
25+
super(MergeDuplicatedNodesOptimizer, self).__init__(debug)
26+
# optimizer should have name and log property
27+
self._name = "MergeDuplicatedNodesOptimizer"
28+
self._log = logging.getLogger("tf2onnx.optimizer.%s" % self._name)
29+
# used internally
2630
self._graph_can_be_optimized = True
2731

2832
def _optimize(self, graph):
@@ -46,7 +50,7 @@ def _merge_duplicated_nodes(self, graph):
4650
def _group_nodes_by_type_inputs(self, graph):
4751
res = defaultdict(list)
4852
for node in graph.get_nodes():
49-
res[self._key_to_group_nodes(node.type, tuple(node.input))].append(node)
53+
res[self._KeyToGroupNodes(node.type, tuple(node.input))].append(node)
5054
return res
5155

5256
def _del_nodes_if_duplicated(self, nodes_group, graph):

tf2onnx/optimizer/optimizer_base.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,13 @@
55

66
from __future__ import unicode_literals
77

8-
import logging
9-
10-
from tf2onnx import utils
11-
128

139
class GraphOptimizerBase(object):
1410
"""optimizer graph to improve performance
1511
"""
1612

17-
def __init__(self, name="GraphOptimizerBase", debug=False):
13+
def __init__(self, debug=False):
1814
self._debug = debug
19-
utils.make_sure(name is not None, "name of optimizer is needed.")
20-
self._name = name
21-
self._log = logging.getLogger("tf2onnx.optimizer.%s" % self._name)
2215

2316
def optimize(self, graph):
2417
original_node_statistics = graph.dump_node_statistics()

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import numpy as np
1212

1313
from tf2onnx import utils
14+
from tf2onnx.optimizer.optimizer_base import GraphOptimizerBase
1415

1516
logging.basicConfig(level=logging.INFO)
1617
log = logging.getLogger("tf2onnx.optimizer.transpose_optimizer")
@@ -35,11 +36,12 @@ def is_useless_transpose(transpose_node):
3536
return transpose_node.type == "Transpose" and perm_attr and perm_attr.ints == list(range(len(perm_attr.ints)))
3637

3738

38-
class TransposeOptimizer(object):
39+
class TransposeOptimizer(GraphOptimizerBase):
3940
"""Transpose Optimizer."""
4041

4142
def __init__(self, debug=False):
42-
self._debug = debug
43+
super(TransposeOptimizer, self).__init__(debug)
44+
self._name = "TransposeOptimizer"
4345
self._handler_map = {}
4446
self._force_stop = {}
4547

0 commit comments

Comments
 (0)