Skip to content

Commit cc3ba76

Browse files
author
chengduo
authored
[Cherry pick] Fix backward error (#18835)
* fix backward bug
1 parent 46c5345 commit cc3ba76

File tree

8 files changed

+289
-26
lines changed

8 files changed

+289
-26
lines changed

paddle/fluid/framework/parallel_executor.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -650,7 +650,7 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
650650
"The number(%d) of samples of "
651651
"current batch is less than the count(%d) of "
652652
"devices(%s), currently, it is not allowed. ",
653-
lod_tensors.size(), lod_tensors.size(),
653+
lod_tensors.size(), member_->places_.size(),
654654
(is_cpu_place ? "CPU" : "GPU"));
655655
if (is_cpu_place) {
656656
error_info +=

paddle/fluid/op_use_default_grad_op_maker.spec

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,6 @@ fusion_seqexpand_concat_fc
1515
fusion_seqpool_concat
1616
fusion_squared_mat_sub
1717
gru
18-
hierarchical_sigmoid
1918
lrn
2019
lstm_unit
2120
max_pool2d_with_index

paddle/fluid/operators/hierarchical_sigmoid_op.cc

Lines changed: 42 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -86,6 +86,10 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
8686
}
8787
};
8888

89+
/*
90+
* Inputs: X, W, Label, PathTable, PathCode, Bias
91+
* Outputs: Out, PreOut, W_out
92+
*/
8993
template <typename AttrType>
9094
class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
9195
public:
@@ -162,6 +166,37 @@ Hierarchical Probabilistic Neural Network Language Model."
162166
}
163167
};
164168

169+
/*
170+
* Inputs: X, W, Label, PathTable, PathCode, PreOut, Out@GRAD
171+
* Outputs: X@GRAD, W@GRAD, Bias@GRAD
172+
*/
173+
class HierarchicalSigmoidGradMaker : public framework::SingleGradOpDescMaker {
174+
public:
175+
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
176+
177+
std::unique_ptr<framework::OpDesc> Apply() const override {
178+
auto* op = new framework::OpDesc();
179+
op->SetType(this->ForwardOpType() + "_grad");
180+
// Inputs: X, W, Label, PathTable, PathCode, PreOut, Out@GRAD
181+
op->SetInput("X", Input("X"));
182+
op->SetInput("W", Input("W"));
183+
op->SetInput("Bias", Input("Bias"));
184+
op->SetInput("Label", Input("Label"));
185+
op->SetInput("PathTable", Input("PathTable"));
186+
op->SetInput("PathCode", Input("PathCode"));
187+
op->SetInput("PreOut", Output("PreOut"));
188+
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
189+
190+
// Outputs: X@GRAD, W@GRAD, Bias@GRAD
191+
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
192+
op->SetOutput(framework::GradVarName("W"), InputGrad("W"));
193+
op->SetOutput(framework::GradVarName("Bias"), InputGrad("Bias"));
194+
op->SetAttrMap(Attrs());
195+
196+
return std::unique_ptr<framework::OpDesc>(op);
197+
}
198+
};
199+
165200
class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
166201
public:
167202
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -209,17 +244,17 @@ class HierarchicalSigmoidGradOpGradVarTypeInference
209244
auto attr = ctx->GetAttr("is_sparse");
210245
bool is_sparse = boost::get<bool>(attr);
211246
if (is_sparse) {
212-
VLOG(30) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W")
213-
<< " is set to SelectedRows";
247+
VLOG(3) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W")
248+
<< " is set to SelectedRows";
214249
ctx->SetType(w_grad_var_name, framework::proto::VarType::SELECTED_ROWS);
215250
} else {
216-
VLOG(30) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W")
217-
<< " is set to LoDTensor";
251+
VLOG(3) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W")
252+
<< " is set to LoDTensor";
218253
ctx->SetType(w_grad_var_name, framework::proto::VarType::LOD_TENSOR);
219254
}
220255
if (hasBias) {
221-
VLOG(30) << "hierarchical_sigmoid_grad op "
222-
<< framework::GradVarName("Bias") << " is set to LoDTensor";
256+
VLOG(3) << "hierarchical_sigmoid_grad op "
257+
<< framework::GradVarName("Bias") << " is set to LoDTensor";
223258
ctx->SetType(bias_grad_var_name, framework::proto::VarType::LOD_TENSOR);
224259
}
225260
ctx->SetDataType(w_grad_var_name, ctx->GetDataType(ctx->Input("W")[0]));
@@ -232,7 +267,7 @@ class HierarchicalSigmoidGradOpGradVarTypeInference
232267
namespace ops = paddle::operators;
233268
REGISTER_OPERATOR(hierarchical_sigmoid, ops::HierarchicalSigmoidOp,
234269
ops::HierarchicalSigmoidOpMaker<int>,
235-
paddle::framework::DefaultGradOpDescMaker<true>);
270+
ops::HierarchicalSigmoidGradMaker);
236271
REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp,
237272
ops::HierarchicalSigmoidGradOpGradVarTypeInference);
238273
REGISTER_OP_CPU_KERNEL(

paddle/fluid/operators/scatter_op.cc

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -58,10 +58,14 @@ class ScatterGradOp : public framework::OperatorWithKernel {
5858
using framework::OperatorWithKernel::OperatorWithKernel;
5959

6060
void InferShape(framework::InferShapeContext* ctx) const override {
61-
ctx->SetOutputDim(framework::GradVarName("Updates"),
62-
ctx->GetInputDim("Updates"));
63-
ctx->SetOutputDim(framework::GradVarName("X"),
64-
ctx->GetInputDim(framework::GradVarName("Out")));
61+
if (ctx->HasOutput(framework::GradVarName("Updates"))) {
62+
ctx->SetOutputDim(framework::GradVarName("Updates"),
63+
ctx->GetInputDim("Updates"));
64+
}
65+
if (ctx->HasOutput(framework::GradVarName("X"))) {
66+
ctx->SetOutputDim(framework::GradVarName("X"),
67+
ctx->GetInputDim(framework::GradVarName("Out")));
68+
}
6569
}
6670

6771
protected:

paddle/fluid/operators/scatter_op.cu

Lines changed: 9 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -47,12 +47,15 @@ class ScatterGradOpCUDAKernel : public framework::OpKernel<T> {
4747
auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
4848
auto *Ids = ctx.Input<Tensor>("Ids");
4949
auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
50-
51-
// In place gradient: dX = dO
52-
dX->ShareDataWith(*dOut);
53-
dUpdates->mutable_data<T>(ctx.GetPlace());
54-
// Gradient by Gather: dUpdates = dO[Ids]
55-
GPUGather<T>(ctx.device_context(), *dOut, *Ids, dUpdates);
50+
if (dX) {
51+
// In place gradient: dX = dO
52+
framework::TensorCopy(*dOut, ctx.GetPlace(), dX);
53+
}
54+
if (dUpdates) {
55+
dUpdates->mutable_data<T>(ctx.GetPlace());
56+
// Gradient by Gather: dUpdates = dO[Ids]
57+
GPUGather<T>(ctx.device_context(), *dOut, *Ids, dUpdates);
58+
}
5659
}
5760
};
5861

paddle/fluid/operators/scatter_op.h

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -74,11 +74,15 @@ class ScatterGradientOpKernel : public framework::OpKernel<T> {
7474
auto *Ids = ctx.Input<Tensor>("Ids");
7575
auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
7676

77-
// In place gradient: dX = dO
78-
framework::TensorCopySync(*dOut, ctx.GetPlace(), dX);
79-
dUpdates->mutable_data<T>(ctx.GetPlace());
80-
// Gradient by Gather: dUpdates = dO[Ids]
81-
CPUGather<T>(ctx.device_context(), *dOut, *Ids, dUpdates);
77+
if (dX) {
78+
// In place gradient: dX = dO
79+
framework::TensorCopySync(*dOut, ctx.GetPlace(), dX);
80+
}
81+
if (dUpdates) {
82+
dUpdates->mutable_data<T>(ctx.GetPlace());
83+
// Gradient by Gather: dUpdates = dO[Ids]
84+
CPUGather<T>(ctx.device_context(), *dOut, *Ids, dUpdates);
85+
}
8286
}
8387
};
8488

python/paddle/fluid/backward.py

Lines changed: 150 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -247,6 +247,125 @@ def _op_can_be_removed_(op_desc, no_grad_set):
247247
return op_descs
248248

249249

250+
def _find_not_need_ops(grad_op_descs, forward_ops, input_grad_names_set):
251+
"""
252+
Pruning Program with Structural Analysis Method of Computational Graph.
253+
The nodes of the computational graph composed of backward OPS should be
254+
interconnected. If there are unconnected sub-graphs in the computational graph,
255+
these sub-graphs should be cut off.
256+
257+
Args:
258+
grad_op_descs(list[core.OpDesc]): The candidate backward OpDescs.
259+
forward_ops(list[Operator]): The forward ops.
260+
input_grad_names_set(set): this set is used to store the gradients' name
261+
which is generated by backward ops, and input_grad_names_set can help
262+
to prune the unnecessary backward ops.
263+
264+
Return:
265+
(list[core.OpDesc]): A list of OpDescs which should be pruned.
266+
"""
267+
268+
class Var(object):
269+
def __init__(self, var_name):
270+
self.var_name = var_name
271+
self.gen_op = None
272+
self.pendding_ops = []
273+
274+
def set_gen_op(self, gen_op):
275+
assert isinstance(gen_op, Op)
276+
assert self.gen_op is None
277+
self.gen_op = gen_op
278+
279+
def add_pending_op(self, op):
280+
assert isinstance(op, Op)
281+
self.pendding_ops.append(op)
282+
283+
class Op(object):
284+
def __init__(self, op_desc):
285+
self.op_desc = op_desc
286+
self.inputs = []
287+
self.outputs = []
288+
289+
def insert_input(self, var):
290+
assert isinstance(var, Var)
291+
self.inputs.append(var)
292+
293+
def insert_output(self, var):
294+
assert isinstance(var, Var)
295+
self.outputs.append(var)
296+
297+
var_versions = dict()
298+
299+
def _create_node(name):
300+
if name not in var_versions.keys():
301+
var_versions[name] = [Var(name)]
302+
else:
303+
var_versions[name].append(Var(name))
304+
return var_versions[name][-1]
305+
306+
def _create_or_get_last_version_node(name):
307+
if name not in var_versions.keys():
308+
var_versions[name] = [Var(name)]
309+
return var_versions[name][-1]
310+
311+
def _create_op_node(op_desc):
312+
op_node = Op(op_desc)
313+
for input in op_desc.input_arg_names():
314+
var = _create_or_get_last_version_node(name=input)
315+
var.add_pending_op(op_node)
316+
op_node.insert_input(var)
317+
for output in op_desc.output_arg_names():
318+
var = _create_node(name=output)
319+
var.set_gen_op(op_node)
320+
op_node.insert_output(var)
321+
return op_node
322+
323+
# Record the forward vars
324+
forward_vars_set = set() if input_grad_names_set is None else set(
325+
input_grad_names_set)
326+
for op in forward_ops:
327+
forward_vars_set.update(op.desc.input_arg_names())
328+
forward_vars_set.update(op.desc.output_arg_names())
329+
330+
# Record the vars which are created during backward and is not generated by op.
331+
backward_vars_set = set()
332+
# special_op_nodes is the candidate sub-graph head node.
333+
special_op_nodes = set()
334+
for op_desc in grad_op_descs:
335+
input_set = set(op_desc.input_arg_names())
336+
# The new_vars are created during backward and is not generated by op.
337+
new_vars = input_set - forward_vars_set - backward_vars_set
338+
backward_vars_set.update(op_desc.output_arg_names())
339+
340+
op_node = _create_op_node(op_desc)
341+
if len(new_vars) == len(input_set):
342+
special_op_nodes.add(op_node)
343+
344+
not_need_op_descs = []
345+
# Start traversing all candidate sub-graph headers to check whether
346+
# they are connected to backward computational graphs, and if they are
347+
# not, list them in not_need_op_descs
348+
for special_op_node in special_op_nodes:
349+
op_list = [special_op_node]
350+
ready_vars = set(special_op_node.inputs)
351+
remove_ops = True
352+
candidate_ops = [special_op_node]
353+
while len(candidate_ops) > 0:
354+
op_node = candidate_ops.pop(0)
355+
if _all_in_set_(op_node.inputs, ready_vars):
356+
for out_var in op_node.outputs:
357+
candidate_ops.extend(out_var.pendding_ops)
358+
op_list.extend(out_var.pendding_ops)
359+
ready_vars.update(op_node.outputs)
360+
else:
361+
remove_ops = False
362+
break
363+
if remove_ops:
364+
not_need_op_descs.extend([node.op_desc for node in op_list])
365+
366+
return set(not_need_op_descs)
367+
368+
250369
from .proto import framework_pb2
251370

252371

@@ -276,7 +395,10 @@ def _append_backward_ops_(block,
276395
grad_to_var(dict)(output argument):
277396
key(str): grad variable name
278397
val(str): corresponding forward variable name
279-
callback(callable object): a callable object used to decorate new generated grad ops
398+
callbacks(callable object): a callable object used to decorate new generated grad ops
399+
input_grad_names_set(set): this set is used to store the gradients' name which is
400+
generated by backward ops, and input_grad_names_set can help to prune the unnecessary
401+
backward ops.
280402
"""
281403
if callbacks is not None:
282404
assert (isinstance(callbacks, list))
@@ -342,6 +464,10 @@ def _append_backward_ops_(block,
342464
grad_op_descs = _remove_no_grad_branch_(grad_op_descs,
343465
no_grad_dict[block.idx])
344466

467+
not_need_ops = _find_not_need_ops(grad_op_descs, ops, input_grad_names_set)
468+
grad_op_descs = [
469+
op_desc for op_desc in grad_op_descs if op_desc not in not_need_ops
470+
]
345471
# append op_desc in grad_op_descs to target_block
346472
op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
347473
backward = core.op_proto_and_checker_maker.OpRole.Backward
@@ -552,7 +678,9 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
552678

553679
block_no_grad_set = set(map(_strip_grad_suffix_, no_grad_dict[0]))
554680
op_path = _find_op_path_(root_block, [loss], [], block_no_grad_set)
555-
681+
no_grad_vars = _find_no_grad_vars(root_block, op_path, [loss],
682+
block_no_grad_set)
683+
block_no_grad_set.update(no_grad_vars)
556684
no_grad_dict[0].update(list(map(_append_grad_suffix_, block_no_grad_set)))
557685

558686
input_grad_names_set = None
@@ -630,6 +758,26 @@ def _as_list(x):
630758
return list(x) if isinstance(x, collections.Sequence) else [x]
631759

632760

761+
def _find_no_grad_vars(block, op_path, targets, no_grad_set):
762+
"""
763+
Find the vars which is not used in the program, and
764+
those var belong to no_grad_var.
765+
"""
766+
output_names = set([out.name for out in targets])
767+
no_grad_var = []
768+
for i, op in reversed(list(enumerate(op_path))):
769+
# If the op has sub_block, it is too complicated to find the correct no_grad_var.
770+
if not op.has_attr("sub_block"):
771+
for out_var in op.desc.output_arg_names():
772+
if out_var not in output_names and out_var not in op.desc.input_arg_names(
773+
) and not block.vars[out_var].stop_gradient:
774+
no_grad_var.append(out_var)
775+
for name in op.desc.input_arg_names():
776+
if name not in no_grad_set:
777+
output_names.add(name)
778+
return set(no_grad_var)
779+
780+
633781
def _find_op_path_(block, outputs, inputs, no_grad_set):
634782
"""
635783
no_grad_set will also be changed

0 commit comments

Comments (0)