@@ -42,40 +42,38 @@ def rewrite_constant_fold(g, ops):
42
42
tensorflow missed something, make another pass over the graph and fix what we care about.
43
43
"""
44
44
func_map = {
45
- # "Add": np.add,
46
- # "GreaterEqual": np.greater_equal,
47
- # "Cast": np.cast,
45
+ "Add" : np .add ,
46
+ "GreaterEqual" : np .greater_equal ,
47
+ "Cast" : np .cast ,
48
48
"ConcatV2" : np .concatenate ,
49
- # "Less": np.less,
50
- # "ListDiff": np.setdiff1d,
51
- # "Mul": np.multiply,
52
- # "Pack": np.stack,
53
- # "Range": np.arange,
54
- # "Sqrt": np.sqrt,
55
- # "Sub": np.subtract,
49
+ "Less" : np .less ,
50
+ "ListDiff" : np .setdiff1d ,
51
+ "Mul" : np .multiply ,
52
+ "Pack" : np .stack ,
53
+ "Range" : np .arange ,
54
+ "Sqrt" : np .sqrt ,
55
+ "Sub" : np .subtract ,
56
56
}
57
- ref_cnt_per_node = {}
58
- for idx , op in enumerate (ops ):
59
- for op_input in op .inputs :
60
- if op_input .name not in ref_cnt_per_node :
61
- ref_cnt_per_node [op_input .name ] = 0
62
- ref_cnt_per_node [op_input .name ] += 1
63
57
64
58
# pylint: disable=too-many-nested-blocks
65
59
keep_looking = True
66
60
while keep_looking :
67
61
keep_looking = False
68
62
for idx , op in enumerate (ops ):
69
63
func = func_map .get (op .type )
70
- if func is None :
71
- continue
64
+ if func is None : continue
65
+ if set ( op . output ) & set ( g . outputs ): continue
72
66
try :
73
67
inputs = []
68
+ skip = False
74
69
for node in op .inputs :
75
70
if not node .is_const ():
71
+ skip = True
76
72
break
77
73
inputs .append (node .get_tensor_value (as_list = False ))
78
74
75
+ if skip : continue
76
+
79
77
logger .debug ("op name %s, %s, %s" , op .name , len (op .input ), len (inputs ))
80
78
if inputs and len (op .input ) == len (inputs ):
81
79
logger .info ("folding node type=%s, name=%s" % (op .type , op .name ))
@@ -109,18 +107,15 @@ def rewrite_constant_fold(g, ops):
109
107
old_node_name = op .name
110
108
logger .debug ("create const node [%s] replacing [%s]" , new_node_name , old_node_name )
111
109
ops [idx ] = g .make_const (new_node_name , val )
112
- ref_cnt_per_node [new_node_name ] = ref_cnt_per_node [old_node_name ]
113
110
114
111
logger .debug ("replace old output [%s] with new output [%s]" , old_output_name , new_output_name )
115
112
# need to re-write the consumers input name to use the const name
116
113
consumers = g .find_output_consumers (old_output_name )
117
114
if consumers :
118
115
for consumer in consumers :
119
116
g .replace_input (consumer , old_output_name , new_output_name )
120
- for node in op .inputs :
121
- ref_cnt_per_node [node .name ] -= 1
122
- if ref_cnt_per_node [node .name ] == 0 :
123
- g .remove_node (node .name )
117
+ g .remove_node (old_node_name )
118
+
124
119
# keep looking until there is nothing we can fold.
125
120
# We keep the graph in topological order so if we folded,
126
121
# the result might help a following op.
0 commit comments