
Commit dd8ee69

Merge pull request #11756 from typhoonzero/cherry_pick_bcast_fix
Merge pull request #11728 from typhoonzero/fix_paraexe_bcast
2 parents 40e83c4 + 75a71ba commit dd8ee69

File tree

6 files changed: +37 additions, -13 deletions

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 3 additions & 0 deletions
@@ -483,6 +483,9 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
       }
     } else if (op.Type() == "concat") {
       op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
+      for (auto &varname : op.OutputArgumentNames()) {
+        var_name_on_devices_.emplace(varname, op_dev_id);
+      }
     } else {
       PADDLE_ENFORCE(
           "the distribute training related op should be in [split_byref, "

paddle/fluid/framework/details/ssa_graph_builder.h

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ class SSAGraphBuilder {
   SSAGraphBuilder() {}
   virtual ~SSAGraphBuilder() {}
   virtual std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const = 0;
-  virtual int GetVarDeviceID(const std::string &var_name) const { return -1; }
+  virtual int GetVarDeviceID(const std::string &var_name) const = 0;
 
   DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
 

paddle/fluid/framework/details/ssa_graph_checker.h

Lines changed: 6 additions & 0 deletions
@@ -16,6 +16,8 @@
 
 #include "paddle/fluid/framework/details/ssa_graph_builder.h"
 
+#include <string>
+
 namespace paddle {
 namespace framework {
 namespace details {
@@ -33,6 +35,10 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
     return graph;
   }
 
+  int GetVarDeviceID(const std::string& var_name) const override {
+    return builder_->GetVarDeviceID(var_name);
+  }
+
   bool IsValidGraph(const SSAGraph* graph) const;
 
  private:

paddle/fluid/framework/details/ssa_graph_printer.h

Lines changed: 5 additions & 0 deletions
@@ -15,6 +15,7 @@
 #pragma once
 
 #include <iosfwd>
+#include <string>
 #include "paddle/fluid/framework/details/ssa_graph_builder.h"
 
 namespace paddle {
@@ -55,6 +56,10 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
     return graph;
   }
 
+  int GetVarDeviceID(const std::string& var_name) const override {
+    return builder_->GetVarDeviceID(var_name);
+  }
+
  private:
   std::unique_ptr<SSAGraphPrinter> printer_;
   std::unique_ptr<SSAGraphBuilder> builder_;
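Both wrappers (SSAGraghBuilderWithChecker and SSAGraghBuilderWithPrinter) decorate an inner builder, so once GetVarDeviceID becomes pure virtual in the base class they must forward the query to the builder they wrap rather than fall back on a default that always returns -1. A minimal sketch of that decorator-style forwarding, using simplified names (Builder, RealBuilder, LoggingBuilder) that are not the repository's classes:

// Minimal sketch, assuming a simplified interface: a base builder with a pure
// virtual device lookup and a decorating wrapper that forwards the call.
#include <iostream>
#include <memory>
#include <string>

class Builder {
 public:
  virtual ~Builder() = default;
  // Pure virtual: every builder, including wrappers, must answer this.
  virtual int GetVarDeviceID(const std::string &var_name) const = 0;
};

class RealBuilder : public Builder {
 public:
  int GetVarDeviceID(const std::string &var_name) const override {
    return var_name == "concat_out" ? 1 : -1;  // toy placement decision
  }
};

class LoggingBuilder : public Builder {
 public:
  explicit LoggingBuilder(std::unique_ptr<Builder> inner)
      : inner_(std::move(inner)) {}
  // Forward to the wrapped builder instead of silently returning -1.
  int GetVarDeviceID(const std::string &var_name) const override {
    return inner_->GetVarDeviceID(var_name);
  }

 private:
  std::unique_ptr<Builder> inner_;
};

int main() {
  LoggingBuilder b(std::make_unique<RealBuilder>());
  std::cout << b.GetVarDeviceID("concat_out") << "\n";  // prints 1
}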

paddle/fluid/framework/parallel_executor.cc

Lines changed: 17 additions & 7 deletions
@@ -133,17 +133,18 @@ ParallelExecutor::ParallelExecutor(
 
 void ParallelExecutor::BCastParamsToGPUs(
     const std::unordered_set<std::string> &vars) const {
-  // the the initialize bcast, all vars would be bcast from device(0), otherwise
+  // the the initializing bcast, all vars would be bcast from device(0),
+  // otherwise
   // bcast from the specified device.
-  bool initialize = builder_.get() == nullptr ? true : false;
+  bool initializing = builder_.get() == nullptr ? true : false;
 
   for (auto &var : vars) {
     int var_dev_id =
         builder_.get() == nullptr ? -1 : builder_->GetVarDeviceID(var);
-    if (!initialize && var_dev_id == -1) continue;
+    if (!initializing && var_dev_id == -1) continue;
 
     framework::Variable *main_var = nullptr;
-    if (initialize) {
+    if (initializing) {
       main_var = member_->local_scopes_[0]->FindVar(var);
     } else {
       main_var = member_->local_scopes_[var_dev_id]->FindVar(var);
@@ -164,7 +165,8 @@ void ParallelExecutor::BCastParamsToGPUs(
       auto place = member_->places_[i];
       void *buffer;
 
-      if ((initialize && i == 0) || (!initialize && i == var_dev_id)) {
+      if ((initializing && i == 0) ||
+          (!initializing && static_cast<int>(i) == var_dev_id)) {
         buffer = const_cast<void *>(main_tensor.data<void>());
       } else {
         auto local_scope = member_->local_scopes_[i];
@@ -181,8 +183,16 @@ void ParallelExecutor::BCastParamsToGPUs(
       platform::NCCLGroupGuard guard;
      for (size_t i = 0; i < member_->places_.size(); ++i) {
        auto &nccl_ctx = member_->nccl_ctxs_->at(member_->places_[i]);
-        platform::dynload::ncclBcast(buffers[i], numel, data_type, 0,
-                                     nccl_ctx.comm_, nccl_ctx.stream());
+        if (initializing) {
+          platform::dynload::ncclBcast(buffers[i], numel, data_type, 0,
+                                       nccl_ctx.comm_, nccl_ctx.stream());
+        } else {
+          if (var_dev_id >= 0) {
+            platform::dynload::ncclBcast(buffers[i], numel, data_type,
+                                         var_dev_id, nccl_ctx.comm_,
+                                         nccl_ctx.stream());
+          }
+        }
       }
       member_->nccl_ctxs_->WaitAll();
     }
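The net effect of this change is that the broadcast root is no longer hard-coded to device 0: after the first (initializing) broadcast, each variable is broadcast from the device the graph builder placed it on, and variables without a known device are skipped. A minimal sketch of that root-selection rule only, with an illustrative BcastRoot helper that is not part of the repository:

// Minimal sketch, assuming only the selection rule shown in the diff above:
// which device should act as the NCCL broadcast root for one variable.
#include <cassert>
#include <optional>

std::optional<int> BcastRoot(bool initializing, int var_dev_id) {
  if (initializing) return 0;              // first broadcast: everything comes from device 0
  if (var_dev_id >= 0) return var_dev_id;  // later broadcasts: the owning device is the root
  return std::nullopt;                     // unknown placement: skip the broadcast entirely
}

int main() {
  assert(BcastRoot(true, -1) == 0);   // initializing run ignores placement
  assert(BcastRoot(false, 2) == 2);   // variable pinned to device 2
  assert(!BcastRoot(false, -1));      // unplaced variable is not broadcast
}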

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 5 additions & 5 deletions
@@ -302,7 +302,6 @@ def get_trainer_program(self):
         """
         # remove optimize ops and add a send op to main_program
         delete_ops(self.origin_program.global_block(), self.optimize_ops)
-        # FIXME(typhoonzero): serialize once will fix error occurs when clone.
         self.origin_program.__str__()
         return self.origin_program
 
@@ -383,11 +382,12 @@ def get_pserver_program(self, endpoint):
             if self._is_adam_connected_op(op):
                 global_ops.append(op)
 
-        def __append_optimize_op__(op, block, grad_to_block_id, merged_var):
+        def __append_optimize_op__(op, block, grad_to_block_id, merged_var,
+                                   lr_ops):
             if self._is_optimizer_op(op):
                 self._append_pserver_ops(block, op, endpoint, grad_to_block_id,
                                          self.origin_program, merged_var)
-            else:
+            elif op not in lr_ops:
                 self._append_pserver_non_opt_ops(block, op)
 
         def __op_have_grad_input__(op):
@@ -447,15 +447,15 @@ def __clone_lr_op_sub_block__(op, program, new_block):
                 # optimizer is connected to itself
                 if ufind.is_connected(op, opt_op) and op not in global_ops:
                     __append_optimize_op__(op, per_opt_block, grad_to_block_id,
-                                           merged_var)
+                                           merged_var, lr_ops)
 
         # append global ops
         if global_ops:
             opt_state_block = pserver_program.create_block(
                 pserver_program.num_blocks - 1)
             for glb_op in global_ops:
                 __append_optimize_op__(glb_op, opt_state_block,
-                                       grad_to_block_id, None)
+                                       grad_to_block_id, None, lr_ops)
 
         # process distributed lookup_table
         prefetch_var_name_to_block_id = []
