Skip to content

Commit f2289e3

Browse files
committed
refine transpose optimizer's log
1 parent bd20cc4 commit f2289e3

File tree

2 files changed

+8
-10
lines changed

2 files changed

+8
-10
lines changed

tf2onnx/optimizer/optimizer_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,4 +52,4 @@ def _print_stat_diff(self, nodes_original, nodes_after_optimized):
5252
for key, value in nodes_after_optimized.items():
5353
if value != 0:
5454
res[key] = value
55-
self._log.info("after optimized, the optimization_statistics is %s", res)
55+
self._log.info("the optimization gain is %s", res)

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,6 @@
1313
from tf2onnx import utils
1414
from tf2onnx.optimizer.optimizer_base import GraphOptimizerBase
1515

16-
logging.basicConfig(level=logging.INFO)
17-
log = logging.getLogger("tf2onnx.optimizer.transpose_optimizer")
18-
1916

2017
# pylint: disable=logging-not-lazy,unused-argument,missing-docstring,abstract-method
2118
# FIXME:
@@ -41,6 +38,7 @@ class TransposeOptimizer(GraphOptimizerBase):
4138

4239
def __init__(self, debug=False):
4340
super(TransposeOptimizer, self).__init__("TransposeOptimizer", debug)
41+
self._log = logging.getLogger("tf2onnx.optimizer.%s" % self._name)
4442
self._handler_map = {}
4543
self._force_stop = {}
4644

@@ -158,17 +156,17 @@ def optimize(self, graph):
158156
if "stop" in self._force_stop and self._force_stop["stop"] == 1:
159157
break
160158

161-
log.debug("finish after " + str(iteration_cnt) + " iteration(s)")
159+
self._log.debug("finish after " + str(iteration_cnt) + " iteration(s)")
162160

163161
self.merge_duplicated_transposes()
164162
self.post_optimize_action()
165163

166164
current_counter = self._g.dump_node_statistics()
167165
transpose_cnt = current_counter["Transpose"]
168-
current_counter.subtract(previous_counter)
169-
log.info(" %d transpose op(s) left, ops diff after transpose optimization: %s", transpose_cnt, current_counter)
166+
self._log.info(" %d transpose op(s) left", transpose_cnt)
167+
self._print_stat_diff(previous_counter, current_counter)
170168
if transpose_cnt > 2:
171-
log.warning("please try add --fold_const to help remove more transpose")
169+
self._log.warning("please try add --fold_const to help remove more transpose")
172170
return self._g
173171

174172
def _initialize_handlers(self):
@@ -219,7 +217,7 @@ def _handle_node_having_branches(self, node):
219217
self._g.remove_node(n.name)
220218
return True
221219

222-
log.debug("input transpose does not have single consumer, skipping...")
220+
self._log.debug("input transpose does not have single consumer, skipping...")
223221
return False
224222

225223
# get the input index of transpose op in node's inputs.
@@ -263,7 +261,7 @@ def _handle_nhwc_tranpose(self, trans):
263261
if len(out_nodes) == 1:
264262
p = out_nodes[0]
265263
if p.name in self._output_names:
266-
log.debug("cannot move transpose down since it met output node %s", p.name)
264+
self._log.debug("cannot move transpose down since it met output node %s", p.name)
267265
return False
268266

269267
if p.type in self._handler_map:

0 commit comments

Comments
 (0)