Skip to content

Commit 8c6dae7

Browse files
committed
fix pe bug
1 parent 8231960 commit 8c6dae7

File tree

2 files changed

+10
-0
lines changed

2 files changed

+10
-0
lines changed

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,12 +145,14 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
145145
} else if (IsDistTrainOp(*op, send_op)) {
146146
CreateComputationalOps(&result, *op, 1);
147147
} else if (IsScaleLossOp(*op)) {
148+
CreateComputationalOps(&result, *op, places_.size());
148149
// user can customize loss@grad if not use_default_grad_scale_
149150
if (use_default_grad_scale_) {
150151
CreateScaleLossGradOp(&result);
151152
}
152153
is_forwarding = false;
153154
} else {
155+
if (IsScaleLossGradOp(*op)) continue;
154156
int op_dev_id = GetOpDeviceID(var_name_on_devices, *op);
155157
if (op_dev_id == -1) { // var on all device
156158
CreateComputationalOps(&result, *op, places_.size());
@@ -399,6 +401,12 @@ void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result,
399401
}
400402

401403
bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const {
404+
// FIXME(yy): Do not hard code like this
405+
return op.OutputArgumentNames().size() == 1 &&
406+
(op.OutputArgumentNames()[0]) == loss_var_name_;
407+
}
408+
409+
bool MultiDevSSAGraphBuilder::IsScaleLossGradOp(const OpDesc &op) const {
402410
// FIXME(yy): Do not hard code like this
403411
return op.OutputArgumentNames().size() == 1 &&
404412
op.OutputArgumentNames()[0] == GradVarName(loss_var_name_);

paddle/fluid/framework/details/multi_devices_graph_builder.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
6767

6868
bool IsScaleLossOp(const OpDesc &op) const;
6969

70+
bool IsScaleLossGradOp(const OpDesc &op) const;
71+
7072
void CreateSendOp(SSAGraph *result, const OpDesc &op) const;
7173

7274
/**

0 commit comments

Comments
 (0)