Skip to content

Commit bd20cc4

Browse files
committed
refine identity optimizer's log
1 parent e8ec761 commit bd20cc4

File tree

1 file changed

+10
-15
lines changed

1 file changed

+10
-15
lines changed

tf2onnx/optimizer/identity_optimizer.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,6 @@
1010

1111
from tf2onnx.optimizer.optimizer_base import GraphOptimizerBase
1212

13-
14-
log = logging.getLogger("tf2onnx.optimizer.identity_optimizer")
15-
16-
1713
# pylint: disable=logging-not-lazy,unused-argument,missing-docstring,unused-variable,arguments-differ
1814

1915

@@ -22,7 +18,7 @@ class IdentityOptimizer(GraphOptimizerBase):
2218

2319
def __init__(self, debug=False):
2420
super(IdentityOptimizer, self).__init__("IdentityOptimizer", debug)
25-
21+
self._log = logging.getLogger("tf2onnx.optimizer.%s" % self._name)
2622
self._g = None
2723

2824
def optimize(self, graph):
@@ -31,8 +27,8 @@ def optimize(self, graph):
3127
self._optimize_recursively(self._g)
3228
current_counter = self._g.dump_node_statistics()
3329
identity_cnt = current_counter["Identity"]
34-
current_counter.subtract(previous_counter)
35-
log.info(" %d identity op(s) left, ops diff after identity optimization: %s", identity_cnt, current_counter)
30+
self._log.info(" %d identity op(s) left", identity_cnt)
31+
self._print_stat_diff(previous_counter, current_counter)
3632
return self._g
3733

3834
def _optimize_recursively(self, g):
@@ -42,9 +38,9 @@ def _optimize_recursively(self, g):
4238
body_graphs = n.get_body_graphs()
4339
if body_graphs:
4440
for attr, b_g in body_graphs.items():
45-
log.debug("start handling subgraph of %s's attribute %s", n.name, attr)
41+
self._log.debug("start handling subgraph of %s's attribute %s", n.name, attr)
4642
self._optimize_recursively(b_g)
47-
log.debug("finish handling subgraph of %s's attribute %s", n.name, attr)
43+
self._log.debug("finish handling subgraph of %s's attribute %s", n.name, attr)
4844

4945
def _optimize(self, g):
5046
has_update = True
@@ -53,7 +49,7 @@ def _optimize(self, g):
5349
nodes = [n for n in g.get_nodes() if n.type == "Identity"]
5450
for n in nodes:
5551
if n.graph is None:
56-
log.info("node has been removed from this graph, skip")
52+
self._log.info("node has been removed from this graph, skip")
5753
continue
5854

5955
graph_outputs = set(n.output).intersection(g.outputs)
@@ -72,19 +68,18 @@ def _handle_non_graph_output_identity(graph, identity):
7268
graph.remove_node(identity.name)
7369
return True
7470

75-
@staticmethod
76-
def _handle_graph_output_identity(graph, identity, graph_outputs):
71+
def _handle_graph_output_identity(self, graph, identity, graph_outputs):
7772
input_id = identity.input[0]
7873
input_node = identity.inputs[0]
7974

8075
if input_node.graph != graph:
8176
# If input node is in parent graph, we don't handle it now
82-
log.debug("input node in parent graph, skip")
77+
self._log.debug("input node in parent graph, skip")
8378
return False
8479

8580
if input_node.is_graph_input():
8681
# Identity between input and output should not be removed.
87-
log.debug("skip identity between input and output")
82+
self._log.debug("skip identity between input and output")
8883
return False
8984

9085
output_id = identity.output[0]
@@ -93,7 +88,7 @@ def _handle_graph_output_identity(graph, identity, graph_outputs):
9388
if input_id in graph.outputs:
9489
# input id already be graph output, so we cannot make that be another graph output.
9590
# this Identity must be kept.
96-
log.debug("identity input already be graph output")
91+
self._log.debug("identity input already be graph output")
9792
return False
9893

9994
graph.remove_node(identity.name)

0 commit comments

Comments
 (0)