13
13
from tf2onnx import utils
14
14
from tf2onnx .optimizer .optimizer_base import GraphOptimizerBase
15
15
16
- logging .basicConfig (level = logging .INFO )
17
- log = logging .getLogger ("tf2onnx.optimizer.transpose_optimizer" )
18
-
19
16
20
17
# pylint: disable=logging-not-lazy,unused-argument,missing-docstring,abstract-method
21
18
# FIXME:
@@ -41,6 +38,7 @@ class TransposeOptimizer(GraphOptimizerBase):
41
38
42
39
def __init__ (self , debug = False ):
43
40
super (TransposeOptimizer , self ).__init__ ("TransposeOptimizer" , debug )
41
+ self ._log = logging .getLogger ("tf2onnx.optimizer.%s" % self ._name )
44
42
self ._handler_map = {}
45
43
self ._force_stop = {}
46
44
@@ -158,17 +156,17 @@ def optimize(self, graph):
158
156
if "stop" in self ._force_stop and self ._force_stop ["stop" ] == 1 :
159
157
break
160
158
161
- log .debug ("finish after " + str (iteration_cnt ) + " iteration(s)" )
159
+ self . _log .debug ("finish after " + str (iteration_cnt ) + " iteration(s)" )
162
160
163
161
self .merge_duplicated_transposes ()
164
162
self .post_optimize_action ()
165
163
166
164
current_counter = self ._g .dump_node_statistics ()
167
165
transpose_cnt = current_counter ["Transpose" ]
168
- current_counter . subtract ( previous_counter )
169
- log . info ( " %d transpose op(s) left, ops diff after transpose optimization: %s" , transpose_cnt , current_counter )
166
+ self . _log . info ( " %d transpose op(s) left" , transpose_cnt )
167
+ self . _print_stat_diff ( previous_counter , current_counter )
170
168
if transpose_cnt > 2 :
171
- log .warning ("please try add --fold_const to help remove more transpose" )
169
+ self . _log .warning ("please try add --fold_const to help remove more transpose" )
172
170
return self ._g
173
171
174
172
def _initialize_handlers (self ):
@@ -219,7 +217,7 @@ def _handle_node_having_branches(self, node):
219
217
self ._g .remove_node (n .name )
220
218
return True
221
219
222
- log .debug ("input transpose does not have single consumer, skipping..." )
220
+ self . _log .debug ("input transpose does not have single consumer, skipping..." )
223
221
return False
224
222
225
223
# get the input index of transpose op in node's inputs.
@@ -263,7 +261,7 @@ def _handle_nhwc_tranpose(self, trans):
263
261
if len (out_nodes ) == 1 :
264
262
p = out_nodes [0 ]
265
263
if p .name in self ._output_names :
266
- log .debug ("cannot move transpose down since it met output node %s" , p .name )
264
+ self . _log .debug ("cannot move transpose down since it met output node %s" , p .name )
267
265
return False
268
266
269
267
if p .type in self ._handler_map :
0 commit comments