@@ -247,6 +247,125 @@ def _op_can_be_removed_(op_desc, no_grad_set):
247
247
return op_descs
248
248
249
249
250
def _find_not_need_ops(grad_op_descs, forward_ops, input_grad_names_set):
    """
    Prune the candidate backward program with a structural analysis of its
    computational graph.

    The nodes of the computational graph composed of backward ops should be
    interconnected. If there are unconnected sub-graphs in the computational
    graph, these sub-graphs should be cut off.

    Args:
        grad_op_descs(list[core.OpDesc]): The candidate backward OpDescs.
        forward_ops(list[Operator]): The forward ops.
        input_grad_names_set(set|None): this set is used to store the
            gradients' names which are generated by backward ops, and it can
            help to prune the unnecessary backward ops.

    Return:
        (set[core.OpDesc]): The OpDescs which should be pruned.
    """

    class Var(object):
        # A single "version" of a variable in the backward graph.
        def __init__(self, var_name):
            self.var_name = var_name
            # The unique op that generates this version of the variable.
            self.gen_op = None
            # Ops that consume this variable version as an input.
            self.pending_ops = []

        def set_gen_op(self, gen_op):
            assert isinstance(gen_op, Op)
            # Each variable version may be generated by at most one op.
            assert self.gen_op is None
            self.gen_op = gen_op

        def add_pending_op(self, op):
            assert isinstance(op, Op)
            self.pending_ops.append(op)

    class Op(object):
        # An operator node wired to its input/output Var nodes.
        def __init__(self, op_desc):
            self.op_desc = op_desc
            self.inputs = []
            self.outputs = []

        def insert_input(self, var):
            assert isinstance(var, Var)
            self.inputs.append(var)

        def insert_output(self, var):
            assert isinstance(var, Var)
            self.outputs.append(var)

    # Maps variable name -> list of Var nodes, one per SSA-like "version"
    # (a new version is appended every time an op writes the name).
    var_versions = dict()

    def _create_node(name):
        # Always create a fresh version of the variable (used for outputs).
        if name not in var_versions:
            var_versions[name] = [Var(name)]
        else:
            var_versions[name].append(Var(name))
        return var_versions[name][-1]

    def _create_or_get_last_version_node(name):
        # Reuse the latest version if it exists (used for inputs).
        if name not in var_versions:
            var_versions[name] = [Var(name)]
        return var_versions[name][-1]

    def _create_op_node(op_desc):
        # Build an Op node and connect it to its input/output Var nodes.
        op_node = Op(op_desc)
        for input_name in op_desc.input_arg_names():
            var = _create_or_get_last_version_node(name=input_name)
            var.add_pending_op(op_node)
            op_node.insert_input(var)
        for output_name in op_desc.output_arg_names():
            var = _create_node(name=output_name)
            var.set_gen_op(op_node)
            op_node.insert_output(var)
        return op_node

    # Record the forward vars
    forward_vars_set = set() if input_grad_names_set is None else set(
        input_grad_names_set)
    for op in forward_ops:
        forward_vars_set.update(op.desc.input_arg_names())
        forward_vars_set.update(op.desc.output_arg_names())

    # Record the vars which are created during backward and are not generated by op.
    backward_vars_set = set()
    # special_op_nodes are the candidate sub-graph head nodes: ops whose
    # inputs are ALL newly created vars, i.e. connected to neither the
    # forward graph nor any previously seen backward output.
    special_op_nodes = set()
    for op_desc in grad_op_descs:
        input_set = set(op_desc.input_arg_names())
        # The new_vars are created during backward and are not generated by op.
        new_vars = input_set - forward_vars_set - backward_vars_set
        backward_vars_set.update(op_desc.output_arg_names())

        op_node = _create_op_node(op_desc)
        if len(new_vars) == len(input_set):
            special_op_nodes.add(op_node)

    not_need_op_descs = []
    # Start traversing all candidate sub-graph headers to check whether
    # they are connected to backward computational graphs, and if they are
    # not, list them in not_need_op_descs
    for special_op_node in special_op_nodes:
        op_list = [special_op_node]
        ready_vars = set(special_op_node.inputs)
        remove_ops = True
        candidate_ops = [special_op_node]
        while len(candidate_ops) > 0:
            op_node = candidate_ops.pop(0)
            if _all_in_set_(op_node.inputs, ready_vars):
                for out_var in op_node.outputs:
                    candidate_ops.extend(out_var.pending_ops)
                    op_list.extend(out_var.pending_ops)
                ready_vars.update(op_node.outputs)
            else:
                # Some input comes from outside this sub-graph, so it is
                # connected to the main backward graph: keep all its ops.
                remove_ops = False
                break
        if remove_ops:
            not_need_op_descs.extend([node.op_desc for node in op_list])

    return set(not_need_op_descs)
367
+
368
+
250
369
from .proto import framework_pb2
251
370
252
371
@@ -276,7 +395,10 @@ def _append_backward_ops_(block,
276
395
grad_to_var(dict)(output argument):
277
396
key(str): grad variable name
278
397
val(str): corresponding forward variable name
279
- callback(callable object): a callable object used to decorate new generated grad ops
398
+ callbacks(list[callable object]): a list of callable objects used to decorate new generated grad ops
399
+ input_grad_names_set(set): this set is used to store the gradients' name which is
400
+ generated by backward ops, and input_grad_names_set can help to prune the unnecessary
401
+ backward ops.
280
402
"""
281
403
if callbacks is not None :
282
404
assert (isinstance (callbacks , list ))
@@ -342,6 +464,10 @@ def _append_backward_ops_(block,
342
464
grad_op_descs = _remove_no_grad_branch_ (grad_op_descs ,
343
465
no_grad_dict [block .idx ])
344
466
467
+ not_need_ops = _find_not_need_ops (grad_op_descs , ops , input_grad_names_set )
468
+ grad_op_descs = [
469
+ op_desc for op_desc in grad_op_descs if op_desc not in not_need_ops
470
+ ]
345
471
# append op_desc in grad_op_descs to target_block
346
472
op_role_attr_name = core .op_proto_and_checker_maker .kOpRoleAttrName ()
347
473
backward = core .op_proto_and_checker_maker .OpRole .Backward
@@ -552,7 +678,9 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
552
678
553
679
block_no_grad_set = set (map (_strip_grad_suffix_ , no_grad_dict [0 ]))
554
680
op_path = _find_op_path_ (root_block , [loss ], [], block_no_grad_set )
555
-
681
+ no_grad_vars = _find_no_grad_vars (root_block , op_path , [loss ],
682
+ block_no_grad_set )
683
+ block_no_grad_set .update (no_grad_vars )
556
684
no_grad_dict [0 ].update (list (map (_append_grad_suffix_ , block_no_grad_set )))
557
685
558
686
input_grad_names_set = None
@@ -630,6 +758,26 @@ def _as_list(x):
630
758
return list (x ) if isinstance (x , collections .Sequence ) else [x ]
631
759
632
760
761
+ def _find_no_grad_vars (block , op_path , targets , no_grad_set ):
762
+ """
763
+ Find the vars which is not used in the program, and
764
+ those var belong to no_grad_var.
765
+ """
766
+ output_names = set ([out .name for out in targets ])
767
+ no_grad_var = []
768
+ for i , op in reversed (list (enumerate (op_path ))):
769
+ # If the op has sub_block, it is too complicated to find the correct no_grad_var.
770
+ if not op .has_attr ("sub_block" ):
771
+ for out_var in op .desc .output_arg_names ():
772
+ if out_var not in output_names and out_var not in op .desc .input_arg_names (
773
+ ) and not block .vars [out_var ].stop_gradient :
774
+ no_grad_var .append (out_var )
775
+ for name in op .desc .input_arg_names ():
776
+ if name not in no_grad_set :
777
+ output_names .add (name )
778
+ return set (no_grad_var )
779
+
780
+
633
781
def _find_op_path_ (block , outputs , inputs , no_grad_set ):
634
782
"""
635
783
no_grad_set will also be changed
0 commit comments