Commit ec6ee0a

simplify and hide bcast_params
1 parent d9297c1 commit ec6ee0a

2 files changed, +7 −40 lines

paddle/fluid/framework/parallel_executor.cc

Lines changed: 6 additions & 39 deletions
@@ -209,30 +209,9 @@ ParallelExecutor::ParallelExecutor(
 void ParallelExecutor::BCastParamsToDevices(
     const std::unordered_set<std::string> &vars) const {
-  // the initializing bcast, all vars would be bcast from device(0),
-  // otherwise
-  // bcast from the specified device.
-  bool initializing = member_->executor_ ? false : true;
+  // the initializing bcast, all vars would be bcast from device(0).
   for (auto &var : vars) {
-    int var_dev_id = -1;
-    if (member_->executor_) {
-      auto &sharded_var_device =
-          member_->executor_->Graph().Get<details::ShardedVarDevice>(
-              details::kShardedVarDevice);
-      if (sharded_var_device.find(var) != sharded_var_device.end()) {
-        var_dev_id = sharded_var_device.at(var);
-      }
-    }
-
-    if (!initializing && var_dev_id == -1) continue;
-
-    framework::Variable *main_var = nullptr;
-    if (initializing) {
-      main_var = member_->local_scopes_[0]->FindVar(var);
-    } else {
-      main_var = member_->local_scopes_[var_dev_id]->FindVar(var);
-    }
-
+    framework::Variable *main_var = member_->local_scopes_[0]->FindVar(var);
     if (main_var == nullptr || !main_var->IsType<LoDTensor>()) {
       continue;
     }
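Semantically, this hunk reduces the function to "device 0 owns the master copy; everyone else receives it", dropping the per-variable source lookup through the graph's ShardedVarDevice map. A self-contained toy analogue of that rule (the `Scope` struct and float values here are assumptions for illustration, not Paddle types):

```cpp
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Toy sketch: with a fixed source replica, the per-variable device lookup
// disappears; replica 0 always holds the master copy the others receive.
struct Scope { std::unordered_map<std::string, float> vars; };

void BcastFromReplica0(std::vector<Scope> &scopes,
                       const std::unordered_set<std::string> &vars) {
  for (const auto &var : vars) {
    auto it = scopes[0].vars.find(var);        // master copy on replica 0
    if (it == scopes[0].vars.end()) continue;  // skip vars replica 0 lacks
    for (size_t i = 1; i < scopes.size(); ++i)
      scopes[i].vars[var] = it->second;        // copy to every other replica
  }
}
```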
@@ -248,8 +227,7 @@ void ParallelExecutor::BCastParamsToDevices(
         auto place = member_->places_[i];
         void *buffer;

-        if ((initializing && i == 0) ||
-            (!initializing && static_cast<int>(i) == var_dev_id)) {
+        if (i == 0) {
           buffer = const_cast<void *>(main_tensor.data<void>());
         } else {
           auto local_scope = member_->local_scopes_[i];
@@ -266,29 +244,18 @@ void ParallelExecutor::BCastParamsToDevices(
         platform::NCCLGroupGuard guard;
         for (size_t i = 0; i < member_->places_.size(); ++i) {
           auto &nccl_ctx = member_->nccl_ctxs_->at(member_->places_[i]);
-          if (initializing) {
-            platform::dynload::ncclBcast(buffers[i], numel, data_type, 0,
-                                         nccl_ctx.comm_, nccl_ctx.stream());
-          } else {
-            if (var_dev_id >= 0) {
-              platform::dynload::ncclBcast(buffers[i], numel, data_type,
-                                           var_dev_id, nccl_ctx.comm_,
-                                           nccl_ctx.stream());
-            }
-          }
+          platform::dynload::ncclBcast(buffers[i], numel, data_type, 0,
+                                       nccl_ctx.comm_, nccl_ctx.stream());
         }
         member_->nccl_ctxs_->WaitAll();
       }
-
 #else
       PADDLE_THROW("Not compiled with CUDA");
 #endif
     } else {
       platform::CPUPlace cpu;
       for (size_t i = 0; i < member_->places_.size(); ++i) {
-        if ((initializing && i == 0) ||
-            (!initializing && static_cast<int>(i) == var_dev_id))
-          continue;
+        if (i == 0) continue;

         auto local_scope = member_->local_scopes_[i];
         auto *t = local_scope->Var(var)->GetMutable<LoDTensor>();
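On the GPU path, the two broadcast branches collapse into a single fixed-root call. A standalone sketch of that pattern against the raw NCCL API (the function name `BcastFromDevice0` and the one-comm-one-stream-per-GPU setup are assumptions for illustration; Paddle issues the same calls through `platform::dynload` inside an `NCCLGroupGuard`):

```cpp
#include <cuda_runtime.h>
#include <nccl.h>
#include <vector>

// Broadcast each per-device buffer from rank 0, batching the per-device
// ncclBcast calls in a group so they cannot deadlock one another.
void BcastFromDevice0(const std::vector<void *> &buffers, size_t count,
                      ncclDataType_t dtype,
                      const std::vector<ncclComm_t> &comms,
                      const std::vector<cudaStream_t> &streams) {
  ncclGroupStart();  // analogue of platform::NCCLGroupGuard
  for (size_t i = 0; i < comms.size(); ++i) {
    // root is fixed at 0 now; before this commit it could be var_dev_id
    ncclBcast(buffers[i], count, dtype, /*root=*/0, comms[i], streams[i]);
  }
  ncclGroupEnd();
  for (auto s : streams) cudaStreamSynchronize(s);  // analogue of WaitAll()
}
```

The CPU branch applies the same fixed-source rule: `if (i == 0) continue;` skips the master copy and the loop copies device 0's tensor into every remaining scope.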

paddle/fluid/framework/parallel_executor.h

Lines changed: 1 addition & 1 deletion
@@ -66,9 +66,9 @@ class ParallelExecutor {
   void Run(const std::vector<std::string> &fetch_tensors,
            const std::string &fetched_var_name);

+ private:
   void BCastParamsToDevices(const std::unordered_set<std::string> &vars) const;

- private:
   ParallelExecutorPrivate *member_;
 };
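The header change moves `BCastParamsToDevices` under `private:`, the "hide bcast_params" half of the commit title: outside code can no longer trigger the broadcast directly. A hypothetical call-site sketch, not part of this diff (the guard condition and the `bcast_vars` name are assumptions):

```cpp
// Hypothetical sketch: with the method private, the broadcast runs only
// from inside ParallelExecutor itself, e.g. once during construction
// after the initial parameter values exist in device 0's scope.
ParallelExecutor::ParallelExecutor(/* ... */) {
  // ... build member_->local_scopes_ and the NCCL contexts ...
  if (member_->local_scopes_.size() > 1) {
    BCastParamsToDevices(bcast_vars);  // assumed internal call site
  }
}
```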
