@@ -54,21 +54,15 @@ def rewrite_constant_fold(g, ops):
54
54
"Sqrt" : np .sqrt ,
55
55
"Sub" : np .subtract ,
56
56
}
57
- ref_cnt_per_node = {}
58
- for idx , op in enumerate (ops ):
59
- for op_input in op .inputs :
60
- if op_input .name not in ref_cnt_per_node :
61
- ref_cnt_per_node [op_input .name ] = 0
62
- ref_cnt_per_node [op_input .name ] += 1
63
57
64
58
# pylint: disable=too-many-nested-blocks
65
59
keep_looking = True
66
60
while keep_looking :
67
61
keep_looking = False
68
62
for idx , op in enumerate (ops ):
69
63
func = func_map .get (op .type )
70
- if func is None :
71
- continue
64
+ if func is None : continue
65
+ if set ( op . output ) & set ( g . outputs ): continue
72
66
try :
73
67
inputs = []
74
68
for node in op .inputs :
@@ -109,18 +103,14 @@ def rewrite_constant_fold(g, ops):
109
103
old_node_name = op .name
110
104
logger .debug ("create const node [%s] replacing [%s]" , new_node_name , old_node_name )
111
105
ops [idx ] = g .make_const (new_node_name , val )
112
- ref_cnt_per_node [new_node_name ] = ref_cnt_per_node [old_node_name ]
113
106
114
107
logger .debug ("replace old output [%s] with new output [%s]" , old_output_name , new_output_name )
115
108
# need to re-write the consumers input name to use the const name
116
109
consumers = g .find_output_consumers (old_output_name )
117
110
if consumers :
118
111
for consumer in consumers :
119
112
g .replace_input (consumer , old_output_name , new_output_name )
120
- for node in op .inputs :
121
- ref_cnt_per_node [node .name ] -= 1
122
- if ref_cnt_per_node [node .name ] == 0 :
123
- g .remove_node (node .name )
113
+
124
114
# keep looking until there is nothing we can fold.
125
115
# We keep the graph in topological order so if we folded,
126
116
# the result might help a following op.
@@ -459,8 +449,8 @@ def compat_handler(ctx, node, **kwargs):
459
449
460
450
# pre-processing graph rewrites
461
451
# bi-directional re-writer should be placed after single directional re-writer
462
- rewriters = [rewrite_quantize_and_dequantize , rewrite_transpose , rewrite_flatten , rewrite_gemm ,
463
- rewrite_random_uniform , rewrite_random_uniform_fold_const ,
452
+ rewriters = [rewrite_constant_fold , rewrite_quantize_and_dequantize , rewrite_transpose , rewrite_flatten ,
453
+ rewrite_gemm , rewrite_random_uniform , rewrite_random_uniform_fold_const ,
464
454
rewrite_random_normal , rewrite_dropout , rewrite_eye ,
465
455
rewrite_leakyrelu , rewrite_thresholded_relu , rewrite_conv2d_with_pad ,
466
456
rewrite_single_direction_lstm , rewrite_bi_direction_lstm ,
0 commit comments