Skip to content

Commit cb79b02

Browse files
authored
Merge pull request #12595 from reyoung/fix_scale_loss_with_memopt
Fix bug when memopt optimizes loss.grad and ParallelExecutor is used
2 parents 46fe9ba + c4f8afa commit cb79b02

File tree

3 files changed

+15
-9
lines changed

3 files changed

+15
-9
lines changed

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,8 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
275275
if (strategy_.gradient_scale_ !=
276276
BuildStrategy::GradientScaleStrategy::kCustomized) {
277277
// TODO(paddle-dev): Why is there no input for this op_handle?
278-
CreateScaleLossGradOp(&result);
278+
auto loss_grad_name = node->Op()->OutputArgumentNames()[0];
279+
CreateScaleLossGradOp(&result, loss_grad_name);
279280
}
280281
// This assumes the backward generating code will ensure IsScaleLossOp
281282
// is true only for the op that scale the final scalar loss.
@@ -535,7 +536,8 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(const ir::Graph &graph,
535536
return got == sharded_var_device.end() ? -1 : got->second;
536537
}
537538

538-
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(ir::Graph *result) const {
539+
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(
540+
ir::Graph *result, const std::string &loss_grad_name) const {
539541
for (size_t i = 0; i < places_.size(); ++i) {
540542
// Insert ScaleCost OpHandle
541543
#ifdef PADDLE_WITH_CUDA
@@ -558,10 +560,10 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(ir::Graph *result) const {
558560
// loss->pending_ops_.emplace_back(op_handle);
559561
// op_handle->inputs_.emplace_back(loss);
560562

561-
CreateOpOutput(result, op_handle,
562-
result->CreateEmptyNode(GradVarName(loss_var_name_),
563-
ir::Node::Type::kVariable),
564-
places_[i], i);
563+
CreateOpOutput(
564+
result, op_handle,
565+
result->CreateEmptyNode(loss_grad_name, ir::Node::Type::kVariable),
566+
places_[i], i);
565567
}
566568
}
567569

paddle/fluid/framework/details/multi_devices_graph_builder.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,9 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
7575
void CreateComputationalOps(ir::Graph *result, ir::Node *node,
7676
size_t num_places) const;
7777

78-
void CreateScaleLossGradOp(ir::Graph *result) const;
78+
void CreateScaleLossGradOp(ir::Graph *result,
79+
const std::string &loss_grad_name) const;
80+
7981
VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og,
8082
int dst_dev_id) const;
8183
void CreateComputationalOp(ir::Graph *result, ir::Node *node,

paddle/fluid/framework/operator.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -778,6 +778,7 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
778778
const ExecutionContext& ctx) const {
779779
auto& scope = ctx.scope();
780780
int data_type = -1;
781+
std::string last_input_name;
781782
for (auto& input : this->inputs_) {
782783
for (auto& ipt_name : input.second) {
783784
auto* var = scope.FindVar(ipt_name);
@@ -794,9 +795,10 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
794795
int tmp = static_cast<int>(ToDataType(t->type()));
795796
PADDLE_ENFORCE(
796797
tmp == data_type || data_type == -1,
797-
"DataType of Paddle Op %s must be the same. Get %d != %d", Type(),
798-
data_type, tmp);
798+
"DataType of Paddle Op %s must be the same. Get %s(%d) != %s(%d)",
799+
Type(), last_input_name, data_type, ipt_name, tmp);
799800
data_type = tmp;
801+
last_input_name = ipt_name;
800802
}
801803
}
802804
}

0 commit comments

Comments (0)