@@ -141,9 +141,8 @@ def rewrite_transpose(g, ops):
141
141
dims = [i for i in range (len (shape ) - 1 , - 1 , - 1 )]
142
142
output .set_attr ("perm" , dims )
143
143
g .remove_input (output , output .input [1 ])
144
- for n in set (match .get_nodes ()):
145
- if n != output :
146
- g .remove_node (n .name )
144
+ to_delete = [n for n in match .get_nodes () if n != output ]
145
+ g .safe_remove_nodes (to_delete )
147
146
return ops
148
147
149
148
@@ -175,8 +174,7 @@ def rewrite_random_normal(g, ops):
175
174
attr = {"shape" : shape , "mean" : mean , "scale" : 1.0 , "dtype" : dtype })
176
175
177
176
g .replace_all_inputs (ops , output .output [0 ], new_node .output [0 ])
178
- for n in set (match .get_nodes ()):
179
- g .remove_node (n .name )
177
+ g .safe_remove_nodes (match .get_nodes ())
180
178
return ops
181
179
182
180
@@ -208,8 +206,7 @@ def rewrite_dropout(g, ops):
208
206
dtypes = [g .get_dtype (inputs2 .input [0 ])]
209
207
)
210
208
g .replace_all_inputs (ops , outputs .output [0 ], new_node .output [0 ])
211
- for n in set (match .get_nodes ()):
212
- g .remove_node (n .name )
209
+ g .safe_remove_nodes (match .get_nodes ())
213
210
214
211
# remove dropout if its ratio is 1.0
215
212
for node in g .get_nodes ():
@@ -294,10 +291,8 @@ def rewrite_flatten(g, ops):
294
291
295
292
g .set_shape (out_name , input_shape [:- 2 ] + [new_dim ])
296
293
g .replace_all_inputs (ops , reshape_node .output [0 ], out_name )
297
-
298
- for n in set (match .get_nodes ()):
299
- if n != input_node :
300
- g .remove_node (n .name )
294
+ to_delete = [n for n in match .get_nodes () if n != input_node ]
295
+ g .safe_remove_nodes (to_delete )
301
296
302
297
return ops
303
298
@@ -654,6 +649,14 @@ def run_rewriters(g, funcs, continue_on_error):
654
649
else :
655
650
raise ex
656
651
652
+ if utils .is_debug_mode ():
653
+ broken_outputs = g .check_integrity ()
654
+ if broken_outputs :
655
+ logging .error (
656
+ "After rewriter %s, graph breaks at outputs %s" ,
657
+ func .__name__ , broken_outputs
658
+ )
659
+
657
660
if g .contained_graphs :
658
661
for dict_val in g .contained_graphs .values ():
659
662
for attr_name , b_g in dict_val .items ():
0 commit comments