Commit 585dec3

Author: xuwei06

Calculating gradients for partial graph

Added backward.calc_gradient to back-propagate gradients from the given targets to the given inputs.

1 parent 0ef9dc6 commit 585dec3
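Below is a minimal usage sketch of the new API (not part of this commit's diff; the fluid.layers calls are assumed from the surrounding code base of the same tree):

    # Hypothetical example: build a tiny forward graph, then ask for dy/dx
    # directly instead of appending a full backward pass for a loss.
    import paddle.v2.fluid as fluid
    from paddle.v2.fluid.backward import calc_gradient

    x = fluid.layers.data(name='x', shape=[13], dtype='float32')
    y = fluid.layers.fc(input=x, size=1)

    # Appends grad ops only for the partial graph between x and y and returns
    # the gradient variable of x (None if x does not affect y).
    x_grad = calc_gradient(targets=y, inputs=x)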

5 files changed: +275 −43 lines

paddle/framework/grad_op_desc_maker.h
Lines changed: 5 additions & 1 deletion

@@ -87,7 +87,11 @@ class GradOpDescMakerBase {
     auto onames = this->Output(name);
     ret_val.reserve(onames.size());
     std::transform(onames.begin(), onames.end(), std::back_inserter(ret_val),
-                   GradVarName);
+                   [this](const std::string& fwd_var_name) -> std::string {
+                     auto g_name = GradVarName(fwd_var_name);
+                     (*this->grad_to_var_)[g_name] = fwd_var_name;
+                     return g_name;
+                   });
     return ret_val;
   }
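The change above makes GradOpDescMakerBase record, for each gradient output name it generates, the forward variable it was derived from (via the grad_to_var_ map). A rough Python analogue of that bookkeeping, using Paddle's "@GRAD" naming convention (illustrative only, not code from this commit):

    # Sketch: derive gradient names from forward names and remember the mapping.
    GRAD_SUFFIX = "@GRAD"  # assumed to match framework::GradVarName()

    def make_grad_names(fwd_names, grad_to_var):
        grad_names = []
        for fwd_name in fwd_names:
            g_name = fwd_name + GRAD_SUFFIX
            grad_to_var[g_name] = fwd_name  # same bookkeeping as the new lambda
            grad_names.append(g_name)
        return grad_names

    grad_to_var = {}
    print(make_grad_names(["fc_0.w_0", "fc_0.b_0"], grad_to_var))
    # ['fc_0.w_0@GRAD', 'fc_0.b_0@GRAD'];
    # grad_to_var now maps each gradient name back to its forward variable.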

paddle/framework/op_desc.h
Lines changed: 1 addition & 1 deletion

@@ -129,7 +129,7 @@ class OpDesc {
   }

   proto::OpDesc desc_;
-  // input arg name => output variable names
+  // input arg name => input variable names
   VariableNameMap inputs_;
   // output arg name => output variable names
   VariableNameMap outputs_;

python/paddle/v2/fluid/backward.py
Lines changed: 227 additions & 38 deletions

@@ -1,8 +1,9 @@
 from paddle.v2.fluid import framework as framework
 from . import core
 import collections
+import copy

-__all__ = ['append_backward']
+__all__ = ['append_backward', 'calc_gradient']


 def _rename_arg_(op_descs, old_name, new_name, begin_idx=None, end_idx=None):
@@ -65,6 +66,18 @@ def _all_in_set_(cands, s):
     return True


+def _some_in_set_(cands, s):
+    """
+    Test if some elements of 'cands' are in set 's'
+    """
+    if len(cands) == 0:
+        return False
+    for c in cands:
+        if c in s:
+            return True
+    return False
+
+
 def _strip_grad_suffix_(name):
     """
     Strip the grad suffix from the given varibale name
@@ -169,8 +182,8 @@ def _op_can_be_removed_(op_desc, no_grad_set):
     return op_descs


-def _append_backward_ops_(target,
-                          block,
+def _append_backward_ops_(block,
+                          ops,
                           target_block,
                           no_grad_dict,
                           grad_to_var,

@@ -179,8 +192,8 @@ def _append_backward_ops_(target,
     Create all grad ops, and insert them into given block

     Args:
-        target(Variable): the target variable of forward pass
         block(Block): the block where forward ops are
+        ops(Op): the forward operators whose backward ops need to be added
         target_block(Block): the block which is going to hold new generated grad ops
         no_grad_dict(dict):
             key(int) block index
@@ -202,14 +215,14 @@ def empty_callback(block, context):
     # grad_op_descs holds created grad_op, and will be appended to target_block
     grad_op_descs = []
     program = block.program
-    for op in reversed(block.ops):
+    for op in reversed(ops):
         grad_sub_block_list = []
         # If the op has its own sub-block, deal with the sub-block first
         if op.has_attr("sub_block"):
             sub_block = program.block(op.block_attr("sub_block"))
             grad_sub_block = program.create_block(parent_idx=sub_block.idx)
-            _append_backward_ops_(target, sub_block, grad_sub_block,
-                                  no_grad_dict, grad_to_var, callback)
+            _append_backward_ops_(sub_block, sub_block.ops, grad_sub_block,
+                                  no_grad_dict, grad_to_var)
             grad_sub_block_list.append(grad_sub_block.desc)

         # Getting op's corresponding grad_op

@@ -224,14 +237,6 @@ def empty_callback(block, context):
     grad_op_descs = _remove_no_grad_branch_(grad_op_descs,
                                             no_grad_dict[block.idx])

-    if target_block.idx == 0:
-        grad_op_descs.insert(
-            0,
-            _create_op_desc_("fill_constant", {}, {
-                "Out": [_append_grad_suffix_(target.name)]
-            }, {"shape": [1],
-                "value": 1.0,
-                "dtype": target.dtype}))
     # append op_desc in grad_op_descs to target_block
     for op_desc in grad_op_descs:
         new_op_desc = target_block.desc.append_op()
@@ -252,7 +257,7 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
             In most cases, this dict is generated by _append_backward_ops_()
         grad_info_map(dict)(output argument):
             key(str): forward variable name
-            val(tuple): a tuple of (str, int), str is the corresponding grad name, int is the block index
+            val(tuple): a tuple of (str, Block), str is the corresponding grad name, Block is the block containing grad variable
     """
     for op_idx in range(start_op_idx, block.desc.op_size()):
         op_desc = block.desc.op(op_idx)
@@ -279,41 +284,63 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
             _infer_var_data_type_(arg, block)


+def _rename_grad_(block, start_op_idx, grad_to_var, target_grad_map):
+    var_map = copy.copy(target_grad_map)
+    for op_idx in range(start_op_idx, block.desc.op_size()):
+        op_desc = block.desc.op(op_idx)
+        for name in op_desc.input_arg_names():
+            if name in var_map:
+                op_desc.rename_input(name, var_map[name])
+
+        for name in op_desc.output_arg_names():
+            if block.desc.find_var(name.encode("ascii")):
+                new_name = "%s_%s" % (name, core.unique_integer(name))
+                op_desc.rename_output(name, new_name)
+                var_map[name] = new_name
+
+    for g, ng in var_map.iteritems():
+        if g in grad_to_var:
+            grad_to_var[ng] = grad_to_var[g]
+            grad_to_var.pop(g)
+
+
+def _get_stop_gradients_(program):
+    no_grad_dict = dict()
+    assert isinstance(program, framework.Program)
+    for block in program.blocks:
+        assert isinstance(block, framework.Block)
+        block_no_grad_set = set()
+        for var in block.vars.itervalues():
+            assert isinstance(var, framework.Variable)
+            if var.stop_gradient:
+                block_no_grad_set.add(_append_grad_suffix_(var.name))
+        no_grad_dict[block.idx] = block_no_grad_set
+    return no_grad_dict
+
+
 def append_backward(loss, parameter_list=None, no_grad_set=None, callback=None):
     """
     Append backward part to main_program

     Args:
         loss(Variable): The variable generated by cost function.
-        parameter_list(list): Parameters that need to be updated by optimizer.
-            If None, it means all parameters need to be updated.
+        parameter_list(list[string]): Parameters that need to be updated by
+            optimizer. If None, it means all parameters need to be updated.
         no_grad_set(set): Variables that have no gradients in Block 0.
-            If None, the set will be generated inside the function and
-            contains all variables with `step_gradient=True` from all blocks.
+            All variables with `step_gradient=True` from all blocks will be
+            automatically added.

     Return:
-        (list[Variable]): list of (parameters, gradients) pair.
+        (list[(Variable,Variable)]): list of (parameter, gradient) pair.
     """
     assert isinstance(loss, framework.Variable)

     program = loss.block.program
-    no_grad_dict = dict()
     if no_grad_set is None:
-        assert isinstance(program, framework.Program)
-        for block in program.blocks:
-            assert isinstance(block, framework.Block)
-            block_no_grad_set = set()
-            for var in block.vars.itervalues():
-                assert isinstance(var, framework.Variable)
-                if var.stop_gradient:
-                    block_no_grad_set.add(_append_grad_suffix_(var.name))
-            no_grad_dict[block.idx] = block_no_grad_set
-    elif isinstance(no_grad_set, set):
-        no_grad_dict = {
-            0: set([_append_grad_suffix_(name) for name in no_grad_set])
-        }
-    else:
-        raise ValueError("'no_grad_set' should be a set or None.")
+        no_grad_set = set()
+    no_grad_set = copy.copy(no_grad_set)
+    no_grad_dict = _get_stop_gradients_(program)
+    no_grad_dict[0].update(map(_append_grad_suffix_, no_grad_set))

     grad_info_map = dict()
     root_block = program.block(0)
@@ -322,8 +349,25 @@ def append_backward(loss, parameter_list=None, no_grad_set=None, callback=None):
     current_block_idx = program.current_block_idx
     grad_to_var = dict()

-    _append_backward_ops_(loss, root_block, root_block, no_grad_dict,
+    op_desc = _create_op_desc_("fill_constant", {}, {
+        "Out": [_append_grad_suffix_(loss.name)]
+    }, {"shape": [1],
+        "value": 1.0,
+        "dtype": loss.dtype})
+    root_block.desc.append_op().copy_from(op_desc)
+
+    block_no_grad_set = set(map(_strip_grad_suffix_, no_grad_dict[0]))
+    op_path = _find_op_path_(root_block, [loss], [], block_no_grad_set)
+    no_grad_dict[0].update(map(_append_grad_suffix_, block_no_grad_set))
+
+    _append_backward_ops_(root_block, op_path, root_block, no_grad_dict,
                           grad_to_var, callback)
+
+    # Because calc_gradient may be called multiple times,
+    # we need rename the internal gradient variables so that they have
+    # different names.
+    _rename_grad_(root_block, fwd_op_num, grad_to_var, {})
+
     _append_backward_vars_(root_block, fwd_op_num, grad_to_var, grad_info_map)

     program.current_block_idx = current_block_idx
@@ -334,6 +378,7 @@ def append_backward(loss, parameter_list=None, no_grad_set=None, callback=None):
     else:
         params = program.global_block().all_parameters()
         parameters = [param.name for param in params]
+
     params_and_grads = []
     for param in parameters:
         if param not in grad_info_map:
@@ -351,3 +396,147 @@ def append_backward(loss, parameter_list=None, no_grad_set=None, callback=None):
         else:
             params_and_grads.append((param_var, None))
     return params_and_grads
+
+
+def _as_list(x):
+    if x is None:
+        return []
+    return list(x) if isinstance(x, collections.Sequence) else [x]
+
+
+def _find_op_path_(block, outputs, inputs, no_grad_set):
+    """
+    no_grad_set will also be changed
+    """
+    input_names = set([inp.name for inp in inputs])
+    output_names = set([out.name for out in outputs])
+
+    relevant_op_flags = [True] * len(block.ops)
+
+    # All the inputs of the block are used if inputs is empty,
+    if inputs:
+        for i, op in enumerate(block.ops):
+            if _some_in_set_(op.desc.input_arg_names(), input_names):
+                for name in op.desc.output_arg_names():
+                    if name not in no_grad_set:
+                        input_names.add(name)
+            else:
+                relevant_op_flags[i] = False
+
+    for i, op in reversed(list(enumerate(block.ops))):
+        if _some_in_set_(op.desc.output_arg_names(), output_names):
+            for name in op.desc.input_arg_names():
+                if name not in no_grad_set:
+                    output_names.add(name)
+        else:
+            relevant_op_flags[i] = False
+
+    op_path = [
+        block.ops[i] for i in range(len(block.ops)) if relevant_op_flags[i]
+    ]
+
+    if inputs:
+        for op in op_path:
+            for name in op.desc.input_arg_names():
+                if name not in input_names:
+                    no_grad_set.add(name)
+
+    return op_path
+
+
+def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None):
+    """
+    Backpropagate the graidents of targets to inputs.
+
+    Args:
+        targets(Variable|list[Variable]): The target variables
+        inputs(Variable|list[Variable]): The input variables
+        no_grad_set(set[string]): The names of variables that have no gradients
+            in Block 0. All variables with `stop_gradient=True` from all blocks
+            will be automatically added.
+
+    Return:
+        (list[Variable]): list of gradients for inputs
+        If an input does not affect targets, the corresponding gradient variable
+        will be None
+    """
+    targets = _as_list(targets)
+    inputs = _as_list(inputs)
+    target_gradients = _as_list(target_gradients)
+
+    block = targets[0].block
+    prog = block.program
+    block_idx = block.idx
+
+    if not target_gradients:
+        target_gradients = [None] * len(targets)
+
+    if len(targets) != len(target_gradients):
+        raise ValueError(
+            "Should have the same number of target_gradients as targets")
+
+    if no_grad_set is None:
+        no_grad_set = set()
+    no_grad_set = copy.copy(no_grad_set)
+    no_grad_dict = _get_stop_gradients_(prog)
+    no_grad_dict[0].update(map(_append_grad_suffix_, no_grad_set))
+
+    fwd_op_num = block.desc.op_size()
+
+    target_grad_map = {}
+    for i, grad in enumerate(target_gradients):
+        target = targets[i]
+        if grad is None:
+            grad_name = _append_grad_suffix_(target.name)
+            op_desc = _create_op_desc_("fill_constant_batch_size_like",
+                                       {"Input": [target.name]},
+                                       {"Out": [grad_name]}, {
+                                           "shape": target.shape,
+                                           "value": 1.0,
+                                           "dtype": target.dtype,
+                                           'input_dim_idx': 0,
+                                           'output_dim_idx': 0
+                                       })
+            block.desc.append_op().copy_from(op_desc)
+        else:
+            if target.block.idx != block_idx or target.block.program != prog:
+                raise ValueError("all targets must be in the same block")
+            if target.shape != grad.shape:
+                raise ValueError(
+                    "The shapes of target and grad are different: %s %s" % (
+                        target.name, grad.name))
+            target_grad_map[_append_grad_suffix_(target.name)] = grad.name
+
+    for input in inputs:
+        if input.block.program != prog:
+            raise "input must be in the same program as targets"
+
+    block_no_grad_set = set(map(_strip_grad_suffix_, no_grad_dict[0]))
+    op_path = _find_op_path_(block, targets, inputs, block_no_grad_set)
+    no_grad_dict[0].update(map(_append_grad_suffix_, block_no_grad_set))
+    grad_to_var = dict()
+    grad_info_map = dict()
+    _append_backward_ops_(block, op_path, block, no_grad_dict, grad_to_var)
+
+    # Because calc_gradient may be called multiple times,
+    # we need rename the internal gradient variables so that they have
+    # different names.
+    _rename_grad_(block, fwd_op_num, grad_to_var, target_grad_map)
+
+    _append_backward_vars_(block, fwd_op_num, grad_to_var, grad_info_map)
+    prog.sync_with_cpp()
+
+    grad_vars = []
+    for input_var in inputs:
+        if input_var.name not in grad_info_map:
+            grad_vars.append(None)
+        else:
+            grad_info = grad_info_map[input_var.name]
+            grad_block = grad_info[1]
+            grad_var = grad_block.var(grad_info[0])
+            grad_vars.append(grad_var)
+
+    if len(grad_vars) == 1:
+        return grad_vars[0]
+    else:
+        return grad_vars
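The partial-graph selection is done by the new `_find_op_path_`, which keeps only operators reachable forward from `inputs` and backward from `targets`. A standalone sketch of that two-pass marking over a toy op list (plain tuples instead of fluid Block/Op objects, and no `no_grad_set` handling; illustrative only):

    # Illustrative two-pass reachability pruning over a toy op list.
    # Each op is (op_inputs, op_outputs); this mirrors the idea of
    # _find_op_path_ without using Block/Op objects.
    def find_op_path(ops, target_names, input_names):
        reachable_from_inputs = set(input_names)
        keep = [True] * len(ops)

        # Forward pass: an op is relevant only if it consumes something
        # reachable from the requested inputs (when inputs are given).
        if input_names:
            for i, (ins, outs) in enumerate(ops):
                if reachable_from_inputs & set(ins):
                    reachable_from_inputs.update(outs)
                else:
                    keep[i] = False

        # Backward pass: an op is relevant only if it produces something
        # the targets (transitively) depend on.
        needed = set(target_names)
        for i in reversed(range(len(ops))):
            ins, outs = ops[i]
            if needed & set(outs):
                needed.update(ins)
            else:
                keep[i] = False

        return [ops[i] for i in range(len(ops)) if keep[i]]

    ops = [(["x"], ["h"]), (["h"], ["y"]), (["a"], ["b"])]  # "a"->"b" is unrelated
    print(find_op_path(ops, ["y"], ["x"]))
    # -> [(['x'], ['h']), (['h'], ['y'])]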
