Skip to content

Commit 12ac073

Browse files
authored
Merge pull request #13080 from reyoung/fix_scale_grad_bug
Fix bug when loss@GRAD is reused. (Release Branch)
2 parents e860a5d + 941e835 commit 12ac073

File tree

1 file changed

+3
-11
lines changed

1 file changed

+3
-11
lines changed

paddle/fluid/framework/details/multi_devices_graph_pass.cc

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -625,19 +625,11 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(const ir::Graph &graph,
625625
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(
626626
ir::Graph *result, const std::string &loss_grad_name) const {
627627
for (size_t i = 0; i < places_.size(); ++i) {
628-
// Insert ScaleCost OpHandle
629-
#ifdef PADDLE_WITH_CUDA
630-
auto *communication_dev_ctx =
631-
nccl_ctxs_ ? nccl_ctxs_->DevCtx(places_[i])
632-
: platform::DeviceContextPool::Instance().Get(places_[i]);
633-
#else
634-
auto *communication_dev_ctx =
635-
platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
636-
#endif
628+
// Insert ScaleCost OpHandle
629+
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(places_[i]);
637630
auto *op_handle = new ScaleLossGradOpHandle(
638631
result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation),
639-
local_scopes_.size(), local_scopes_[i], places_[i],
640-
communication_dev_ctx);
632+
local_scopes_.size(), local_scopes_[i], places_[i], dev_ctx);
641633
result->Get<GraphOps>(kGraphOps).emplace_back(op_handle);
642634

643635
// FIXME: Currently ScaleLossGradOp only use device_count as scale

0 commit comments

Comments
 (0)