10
10
11
11
from tf2onnx .optimizer .optimizer_base import GraphOptimizerBase
12
12
13
-
14
- log = logging .getLogger ("tf2onnx.optimizer.identity_optimizer" )
15
-
16
-
17
13
# pylint: disable=logging-not-lazy,unused-argument,missing-docstring,unused-variable,arguments-differ
18
14
19
15
@@ -22,7 +18,7 @@ class IdentityOptimizer(GraphOptimizerBase):
22
18
23
19
def __init__ (self , debug = False ):
24
20
super (IdentityOptimizer , self ).__init__ ("IdentityOptimizer" , debug )
25
-
21
+ self . _log = logging . getLogger ( "tf2onnx.optimizer.%s" % self . _name )
26
22
self ._g = None
27
23
28
24
def optimize (self , graph ):
@@ -31,8 +27,8 @@ def optimize(self, graph):
31
27
self ._optimize_recursively (self ._g )
32
28
current_counter = self ._g .dump_node_statistics ()
33
29
identity_cnt = current_counter ["Identity" ]
34
- current_counter . subtract ( previous_counter )
35
- log . info ( " %d identity op(s) left, ops diff after identity optimization: %s" , identity_cnt , current_counter )
30
+ self . _log . info ( " %d identity op(s) left" , identity_cnt )
31
+ self . _print_stat_diff ( previous_counter , current_counter )
36
32
return self ._g
37
33
38
34
def _optimize_recursively (self , g ):
@@ -42,9 +38,9 @@ def _optimize_recursively(self, g):
42
38
body_graphs = n .get_body_graphs ()
43
39
if body_graphs :
44
40
for attr , b_g in body_graphs .items ():
45
- log .debug ("start handling subgraph of %s's attribute %s" , n .name , attr )
41
+ self . _log .debug ("start handling subgraph of %s's attribute %s" , n .name , attr )
46
42
self ._optimize_recursively (b_g )
47
- log .debug ("finish handling subgraph of %s's attribute %s" , n .name , attr )
43
+ self . _log .debug ("finish handling subgraph of %s's attribute %s" , n .name , attr )
48
44
49
45
def _optimize (self , g ):
50
46
has_update = True
@@ -53,7 +49,7 @@ def _optimize(self, g):
53
49
nodes = [n for n in g .get_nodes () if n .type == "Identity" ]
54
50
for n in nodes :
55
51
if n .graph is None :
56
- log .info ("node has been removed from this graph, skip" )
52
+ self . _log .info ("node has been removed from this graph, skip" )
57
53
continue
58
54
59
55
graph_outputs = set (n .output ).intersection (g .outputs )
@@ -72,19 +68,18 @@ def _handle_non_graph_output_identity(graph, identity):
72
68
graph .remove_node (identity .name )
73
69
return True
74
70
75
- @staticmethod
76
- def _handle_graph_output_identity (graph , identity , graph_outputs ):
71
+ def _handle_graph_output_identity (self , graph , identity , graph_outputs ):
77
72
input_id = identity .input [0 ]
78
73
input_node = identity .inputs [0 ]
79
74
80
75
if input_node .graph != graph :
81
76
# If input node is in parent graph, we don't handle it now
82
- log .debug ("input node in parent graph, skip" )
77
+ self . _log .debug ("input node in parent graph, skip" )
83
78
return False
84
79
85
80
if input_node .is_graph_input ():
86
81
# Identity between input and output should not be removed.
87
- log .debug ("skip identity between input and output" )
82
+ self . _log .debug ("skip identity between input and output" )
88
83
return False
89
84
90
85
output_id = identity .output [0 ]
@@ -93,7 +88,7 @@ def _handle_graph_output_identity(graph, identity, graph_outputs):
93
88
if input_id in graph .outputs :
94
89
# input id already be graph output, so we cannot make that be another graph output.
95
90
# this Identity must be kept.
96
- log .debug ("identity input already be graph output" )
91
+ self . _log .debug ("identity input already be graph output" )
97
92
return False
98
93
99
94
graph .remove_node (identity .name )
0 commit comments