Skip to content

Commit a3ca4c9

Browse files
committed
fix loss.gradVar
1 parent 8c6dae7 commit a3ca4c9

File tree

3 files changed

+2
-10
lines changed

3 files changed

+2
-10
lines changed

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -145,14 +145,12 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
145145
} else if (IsDistTrainOp(*op, send_op)) {
146146
CreateComputationalOps(&result, *op, 1);
147147
} else if (IsScaleLossOp(*op)) {
148-
CreateComputationalOps(&result, *op, places_.size());
149148
// user can customize loss@grad if not use_default_grad_scale_
150149
if (use_default_grad_scale_) {
151150
CreateScaleLossGradOp(&result);
152151
}
153152
is_forwarding = false;
154153
} else {
155-
if (IsScaleLossGradOp(*op)) continue;
156154
int op_dev_id = GetOpDeviceID(var_name_on_devices, *op);
157155
if (op_dev_id == -1) { // var on all device
158156
CreateComputationalOps(&result, *op, places_.size());
@@ -401,12 +399,6 @@ void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result,
401399
}
402400

403401
bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const {
404-
// FIXME(yy): Do not hard code like this
405-
return op.OutputArgumentNames().size() == 1 &&
406-
(op.OutputArgumentNames()[0]) == loss_var_name_;
407-
}
408-
409-
bool MultiDevSSAGraphBuilder::IsScaleLossGradOp(const OpDesc &op) const {
410402
// FIXME(yy): Do not hard code like this
411403
return op.OutputArgumentNames().size() == 1 &&
412404
op.OutputArgumentNames()[0] == GradVarName(loss_var_name_);

paddle/fluid/framework/details/multi_devices_graph_builder.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
6767

6868
bool IsScaleLossOp(const OpDesc &op) const;
6969

70-
bool IsScaleLossGradOp(const OpDesc &op) const;
71-
7270
void CreateSendOp(SSAGraph *result, const OpDesc &op) const;
7371

7472
/**

python/paddle/fluid/backward.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,8 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
480480

481481
program.current_block_idx = current_block_idx
482482
program.sync_with_cpp()
483+
# FIXME(zcd): prevent loss.grad optimized by mem_opt.
484+
loss.block.var(_append_grad_suffix_(loss.name)).persistable = True
483485

484486
if parameter_list is not None:
485487
parameters = parameter_list

0 commit comments

Comments
 (0)