@@ -1,8 +1,9 @@
 from paddle.v2.fluid import framework as framework
 from . import core
 import collections
+import copy

-__all__ = ['append_backward']
+__all__ = ['append_backward', 'calc_gradient']


 def _rename_arg_(op_descs, old_name, new_name, begin_idx=None, end_idx=None):
@@ -65,6 +66,18 @@ def _all_in_set_(cands, s):
     return True


+def _some_in_set_(cands, s):
+    """
+    Test if some elements of 'cands' are in set 's'
+    """
+    if len(cands) == 0:
+        return False
+    for c in cands:
+        if c in s:
+            return True
+    return False
+
+
 def _strip_grad_suffix_(name):
     """
     Strip the grad suffix from the given variable name
@@ -169,8 +182,8 @@ def _op_can_be_removed_(op_desc, no_grad_set):
     return op_descs


-def _append_backward_ops_(target,
-                          block,
+def _append_backward_ops_(block,
+                          ops,
                           target_block,
                           no_grad_dict,
                           grad_to_var,
@@ -179,8 +192,8 @@ def _append_backward_ops_(target,
     Create all grad ops, and insert them into given block

     Args:
-        target(Variable): the target variable of forward pass
         block(Block): the block where forward ops are
+        ops(Op): the forward operators whose backward ops need to be added
         target_block(Block): the block which is going to hold new generated grad ops
         no_grad_dict(dict):
             key(int) block index
@@ -202,14 +215,14 @@ def empty_callback(block, context):
     # grad_op_descs holds created grad_op, and will be appended to target_block
     grad_op_descs = []
     program = block.program
-    for op in reversed(block.ops):
+    for op in reversed(ops):
         grad_sub_block_list = []
         # If the op has its own sub-block, deal with the sub-block first
         if op.has_attr("sub_block"):
             sub_block = program.block(op.block_attr("sub_block"))
             grad_sub_block = program.create_block(parent_idx=sub_block.idx)
-            _append_backward_ops_(target, sub_block, grad_sub_block,
-                                  no_grad_dict, grad_to_var, callback)
+            _append_backward_ops_(sub_block, sub_block.ops, grad_sub_block,
+                                  no_grad_dict, grad_to_var)
             grad_sub_block_list.append(grad_sub_block.desc)

         # Getting op's corresponding grad_op
@@ -224,14 +237,6 @@ def empty_callback(block, context):
     grad_op_descs = _remove_no_grad_branch_(grad_op_descs,
                                             no_grad_dict[block.idx])

-    if target_block.idx == 0:
-        grad_op_descs.insert(
-            0,
-            _create_op_desc_("fill_constant", {}, {
-                "Out": [_append_grad_suffix_(target.name)]
-            }, {"shape": [1],
-                "value": 1.0,
-                "dtype": target.dtype}))
     # append op_desc in grad_op_descs to target_block
     for op_desc in grad_op_descs:
         new_op_desc = target_block.desc.append_op()
@@ -252,7 +257,7 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
             In most cases, this dict is generated by _append_backward_ops_()
         grad_info_map(dict)(output argument):
             key(str): forward variable name
-            val(tuple): a tuple of (str, int), str is the corresponding grad name, int is the block index
+            val(tuple): a tuple of (str, Block), str is the corresponding grad name, Block is the block containing grad variable
     """
     for op_idx in range(start_op_idx, block.desc.op_size()):
         op_desc = block.desc.op(op_idx)
@@ -279,41 +284,63 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
             _infer_var_data_type_(arg, block)


+def _rename_grad_(block, start_op_idx, grad_to_var, target_grad_map):
+    var_map = copy.copy(target_grad_map)
+    for op_idx in range(start_op_idx, block.desc.op_size()):
+        op_desc = block.desc.op(op_idx)
+        for name in op_desc.input_arg_names():
+            if name in var_map:
+                op_desc.rename_input(name, var_map[name])
+
+        for name in op_desc.output_arg_names():
+            if block.desc.find_var(name.encode("ascii")):
+                new_name = "%s_%s" % (name, core.unique_integer(name))
+                op_desc.rename_output(name, new_name)
+                var_map[name] = new_name
+
+    for g, ng in var_map.iteritems():
+        if g in grad_to_var:
+            grad_to_var[ng] = grad_to_var[g]
+            grad_to_var.pop(g)
+
+
+def _get_stop_gradients_(program):
+    no_grad_dict = dict()
+    assert isinstance(program, framework.Program)
+    for block in program.blocks:
+        assert isinstance(block, framework.Block)
+        block_no_grad_set = set()
+        for var in block.vars.itervalues():
+            assert isinstance(var, framework.Variable)
+            if var.stop_gradient:
+                block_no_grad_set.add(_append_grad_suffix_(var.name))
+        no_grad_dict[block.idx] = block_no_grad_set
+    return no_grad_dict
+
+
 def append_backward(loss, parameter_list=None, no_grad_set=None, callback=None):
     """
     Append backward part to main_program

     Args:
         loss(Variable): The variable generated by cost function.
-        parameter_list(list): Parameters that need to be updated by optimizer.
-            If None, it means all parameters need to be updated.
+        parameter_list(list[string]): Parameters that need to be updated by
+            optimizer. If None, it means all parameters need to be updated.
         no_grad_set(set): Variables that have no gradients in Block 0.
-            If None, the set will be generated inside the function and
-            contains all variables with `step_gradient=True` from all blocks.
+            All variables with `stop_gradient=True` from all blocks will be
+            automatically added.

     Return:
-        (list[Variable]): list of (parameters, gradients) pair.
+        (list[(Variable, Variable)]): list of (parameter, gradient) pairs.
     """
     assert isinstance(loss, framework.Variable)

     program = loss.block.program
-    no_grad_dict = dict()
     if no_grad_set is None:
-        assert isinstance(program, framework.Program)
-        for block in program.blocks:
-            assert isinstance(block, framework.Block)
-            block_no_grad_set = set()
-            for var in block.vars.itervalues():
-                assert isinstance(var, framework.Variable)
-                if var.stop_gradient:
-                    block_no_grad_set.add(_append_grad_suffix_(var.name))
-            no_grad_dict[block.idx] = block_no_grad_set
-    elif isinstance(no_grad_set, set):
-        no_grad_dict = {
-            0: set([_append_grad_suffix_(name) for name in no_grad_set])
-        }
-    else:
-        raise ValueError("'no_grad_set' should be a set or None.")
+        no_grad_set = set()
+    no_grad_set = copy.copy(no_grad_set)
+    no_grad_dict = _get_stop_gradients_(program)
+    no_grad_dict[0].update(map(_append_grad_suffix_, no_grad_set))

     grad_info_map = dict()
     root_block = program.block(0)
@@ -322,8 +349,25 @@ def append_backward(loss, parameter_list=None, no_grad_set=None, callback=None):
     current_block_idx = program.current_block_idx
     grad_to_var = dict()

-    _append_backward_ops_(loss, root_block, root_block, no_grad_dict,
+    op_desc = _create_op_desc_("fill_constant", {}, {
+        "Out": [_append_grad_suffix_(loss.name)]
+    }, {"shape": [1],
+        "value": 1.0,
+        "dtype": loss.dtype})
+    root_block.desc.append_op().copy_from(op_desc)
+
+    block_no_grad_set = set(map(_strip_grad_suffix_, no_grad_dict[0]))
+    op_path = _find_op_path_(root_block, [loss], [], block_no_grad_set)
+    no_grad_dict[0].update(map(_append_grad_suffix_, block_no_grad_set))
+
+    _append_backward_ops_(root_block, op_path, root_block, no_grad_dict,
                           grad_to_var, callback)
+
+    # Because calc_gradient may be called multiple times,
+    # we need to rename the internal gradient variables so that they have
+    # different names.
+    _rename_grad_(root_block, fwd_op_num, grad_to_var, {})
+
     _append_backward_vars_(root_block, fwd_op_num, grad_to_var, grad_info_map)

     program.current_block_idx = current_block_idx
@@ -334,6 +378,7 @@ def append_backward(loss, parameter_list=None, no_grad_set=None, callback=None):
     else:
         params = program.global_block().all_parameters()
         parameters = [param.name for param in params]
+
     params_and_grads = []
     for param in parameters:
         if param not in grad_info_map:
@@ -351,3 +396,147 @@ def append_backward(loss, parameter_list=None, no_grad_set=None, callback=None):
         else:
             params_and_grads.append((param_var, None))
     return params_and_grads
+
+
+def _as_list(x):
+    if x is None:
+        return []
+    return list(x) if isinstance(x, collections.Sequence) else [x]
+
+
+def _find_op_path_(block, outputs, inputs, no_grad_set):
+    """
+    no_grad_set will also be changed
+    """
+    input_names = set([inp.name for inp in inputs])
+    output_names = set([out.name for out in outputs])
+
+    relevant_op_flags = [True] * len(block.ops)
+
+    # All the inputs of the block are used if inputs is empty.
+    if inputs:
+        for i, op in enumerate(block.ops):
+            if _some_in_set_(op.desc.input_arg_names(), input_names):
+                for name in op.desc.output_arg_names():
+                    if name not in no_grad_set:
+                        input_names.add(name)
+            else:
+                relevant_op_flags[i] = False
+
+    for i, op in reversed(list(enumerate(block.ops))):
+        if _some_in_set_(op.desc.output_arg_names(), output_names):
+            for name in op.desc.input_arg_names():
+                if name not in no_grad_set:
+                    output_names.add(name)
+        else:
+            relevant_op_flags[i] = False
+
+    op_path = [
+        block.ops[i] for i in range(len(block.ops)) if relevant_op_flags[i]
+    ]
+
+    if inputs:
+        for op in op_path:
+            for name in op.desc.input_arg_names():
+                if name not in input_names:
+                    no_grad_set.add(name)
+
+    return op_path
+
+
+def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None):
+    """
+    Backpropagate the gradients of targets to inputs.
+
+    Args:
+        targets(Variable|list[Variable]): The target variables
+        inputs(Variable|list[Variable]): The input variables
+        no_grad_set(set[string]): The names of variables that have no gradients
+            in Block 0. All variables with `stop_gradient=True` from all blocks
+            will be automatically added.
+
+    Return:
+        (list[Variable]): list of gradients for inputs.
+        If an input does not affect targets, the corresponding gradient variable
+        will be None.
+    """
+    targets = _as_list(targets)
+    inputs = _as_list(inputs)
+    target_gradients = _as_list(target_gradients)
+
+    block = targets[0].block
+    prog = block.program
+    block_idx = block.idx
+
+    if not target_gradients:
+        target_gradients = [None] * len(targets)
+
+    if len(targets) != len(target_gradients):
+        raise ValueError(
+            "Should have the same number of target_gradients as targets")
+
+    if no_grad_set is None:
+        no_grad_set = set()
+    no_grad_set = copy.copy(no_grad_set)
+    no_grad_dict = _get_stop_gradients_(prog)
+    no_grad_dict[0].update(map(_append_grad_suffix_, no_grad_set))
+
+    fwd_op_num = block.desc.op_size()
+
+    target_grad_map = {}
+    for i, grad in enumerate(target_gradients):
+        target = targets[i]
+        if grad is None:
+            grad_name = _append_grad_suffix_(target.name)
+            op_desc = _create_op_desc_("fill_constant_batch_size_like",
+                                       {"Input": [target.name]},
+                                       {"Out": [grad_name]}, {
+                                           "shape": target.shape,
+                                           "value": 1.0,
+                                           "dtype": target.dtype,
+                                           'input_dim_idx': 0,
+                                           'output_dim_idx': 0
+                                       })
+            block.desc.append_op().copy_from(op_desc)
+        else:
+            if target.block.idx != block_idx or target.block.program != prog:
+                raise ValueError("all targets must be in the same block")
+            if target.shape != grad.shape:
+                raise ValueError(
+                    "The shapes of target and grad are different: %s %s" % (
+                        target.name, grad.name))
+            target_grad_map[_append_grad_suffix_(target.name)] = grad.name
+
+    for input in inputs:
+        if input.block.program != prog:
+            raise ValueError("input must be in the same program as targets")
+
+    block_no_grad_set = set(map(_strip_grad_suffix_, no_grad_dict[0]))
+    op_path = _find_op_path_(block, targets, inputs, block_no_grad_set)
+    no_grad_dict[0].update(map(_append_grad_suffix_, block_no_grad_set))
+    grad_to_var = dict()
+    grad_info_map = dict()
+    _append_backward_ops_(block, op_path, block, no_grad_dict, grad_to_var)
+
+    # Because calc_gradient may be called multiple times,
+    # we need to rename the internal gradient variables so that they have
+    # different names.
+    _rename_grad_(block, fwd_op_num, grad_to_var, target_grad_map)
+
+    _append_backward_vars_(block, fwd_op_num, grad_to_var, grad_info_map)
+    prog.sync_with_cpp()
+
+    grad_vars = []
+    for input_var in inputs:
+        if input_var.name not in grad_info_map:
+            grad_vars.append(None)
+        else:
+            grad_info = grad_info_map[input_var.name]
+            grad_block = grad_info[1]
+            grad_var = grad_block.var(grad_info[0])
+            grad_vars.append(grad_var)
+
+    if len(grad_vars) == 1:
+        return grad_vars[0]
+    else:
+        return grad_vars
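Reviewer note: a minimal usage sketch of the two public entry points after this patch, `append_backward` and `calc_gradient`. The layer calls (`fluid.layers.data`, `fluid.layers.fc`, `fluid.layers.mean`) are assumptions about the surrounding `paddle.v2.fluid` API of this period and are only here to build a small graph; they are not part of the patch itself.

```python
# Illustrative sketch only -- not part of this patch.
# The fluid.layers.* helpers below are assumed from the paddle.v2.fluid API.
import paddle.v2.fluid as fluid
from paddle.v2.fluid.backward import append_backward, calc_gradient

x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.fc(input=x, size=1)
loss = fluid.layers.mean(x=y)

# Full backward pass for an optimizer: a list of (parameter, gradient) pairs.
params_and_grads = append_backward(loss)

# Gradient of the loss w.r.t. a chosen input only. A single Variable is
# returned for a single input (a list for several inputs); inputs that do
# not affect the target map to None.
x_grad = calc_gradient(loss, x)
```

The practical difference is that `append_backward` builds gradients for the whole Block 0 graph and pairs them with parameters, while `calc_gradient` restricts the generated grad ops to the op path found by `_find_op_path_` between `inputs` and `targets`.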
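The pruning done by `_find_op_path_` is easier to see outside the framework: one forward sweep from `inputs` and one reverse sweep from `outputs`, keeping only the ops flagged relevant by both. Below is a self-contained toy version of that idea; the `ToyOp` structure and op names are hypothetical, and the real function's `no_grad_set` handling and empty-`inputs` shortcut are omitted.

```python
from collections import namedtuple

# Hypothetical stand-in for an operator: just named input/output variable lists.
ToyOp = namedtuple('ToyOp', ['name', 'inputs', 'outputs'])

ops = [
    ToyOp('mul', ['x', 'w'], ['xw']),
    ToyOp('add', ['xw', 'b'], ['y']),
    ToyOp('side', ['z'], ['unused']),   # unrelated to the x -> loss path
    ToyOp('mean', ['y'], ['loss']),
]


def find_op_path(ops, outputs, inputs):
    """Keep only ops on a path from `inputs` to `outputs` (toy _find_op_path_)."""
    reachable = set(inputs)
    relevant = [True] * len(ops)
    # Forward sweep: an op is relevant if any of its inputs is reachable.
    for i, op in enumerate(ops):
        if any(name in reachable for name in op.inputs):
            reachable.update(op.outputs)
        else:
            relevant[i] = False
    # Reverse sweep: an op stays relevant only if some output feeds the targets.
    needed = set(outputs)
    for i, op in reversed(list(enumerate(ops))):
        if any(name in needed for name in op.outputs):
            needed.update(op.inputs)
        else:
            relevant[i] = False
    return [op.name for i, op in enumerate(ops) if relevant[i]]


print(find_op_path(ops, outputs=['loss'], inputs=['x']))  # ['mul', 'add', 'mean']
```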