Commit a83a4fa

Merge pull request #13441 from panyx0718/ir2
simplify and hide bcast_params

2 parents 2d89849 + e5b3220 commit a83a4fa

File tree

4 files changed: +18 −50 lines changed

  paddle/fluid/framework/details/multi_devices_graph_pass.cc
  paddle/fluid/framework/details/multi_devices_graph_pass.h
  paddle/fluid/framework/parallel_executor.cc
  paddle/fluid/framework/parallel_executor.h

paddle/fluid/framework/details/multi_devices_graph_pass.cc

Lines changed: 3 additions & 0 deletions

@@ -127,6 +127,9 @@ static const char kLocalScopes[] = "local_scopes";
 static const char kStrategy[] = "strategy";
 
 void MultiDevSSAGraphBuilder::Init() const {
+  all_vars_.clear();
+  balance_vars_.clear();
   loss_var_name_ = Get<const std::string>(kLossVarName);
   places_ = Get<const std::vector<platform::Place>>(kPlaces);
   local_scopes_ = Get<const std::vector<Scope *>>(kLocalScopes);
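The two clear() calls make Init() safe to run more than once: the containers are mutable members that survive across applications of the pass, so a second run would otherwise see stale entries from the first. Below is a minimal sketch of that pattern using hypothetical names (CachingPass, Apply, cache_); it is not Paddle code.

#include <iostream>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for a pass that keeps per-run mutable caches.
class CachingPass {
 public:
  void Apply(const std::string &var) const {
    Init();            // reset per-run state before repopulating it
    cache_[var] = 1;
    std::cout << "cache size: " << cache_.size() << "\n";
  }

 private:
  void Init() const {
    cache_.clear();    // mirrors all_vars_.clear() / balance_vars_.clear()
  }
  mutable std::unordered_map<std::string, int> cache_;
};

int main() {
  CachingPass pass;
  pass.Apply("w0");  // cache size: 1
  pass.Apply("w1");  // still 1: Init() dropped the stale "w0" entry
}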

paddle/fluid/framework/details/multi_devices_graph_pass.h

Lines changed: 8 additions & 10 deletions

@@ -40,12 +40,6 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
                          size_t device_id) const;
   void Init() const;
 
- private:
-  mutable std::string loss_var_name_;
-  mutable std::vector<platform::Place> places_;
-  mutable std::vector<Scope *> local_scopes_;
-  mutable std::unordered_set<std::string> grad_names_;
-
 #ifdef PADDLE_WITH_CUDA
   mutable platform::NCCLContextMap *nccl_ctxs_;
 #endif

@@ -95,13 +89,17 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
   size_t GetAppropriateDeviceID(
       const std::vector<std::string> &var_names) const;
 
- private:
+  void SetCommunicationContext(OpHandleBase *op_handle,
+                               const platform::Place &p) const;
+
+  mutable std::string loss_var_name_;
+  mutable std::vector<platform::Place> places_;
+  mutable std::vector<Scope *> local_scopes_;
+  mutable std::unordered_set<std::string> grad_names_;
+
   mutable BuildStrategy strategy_;
   mutable std::unordered_map<std::string, VarDesc *> all_vars_;
   mutable std::vector<int64_t> balance_vars_;
-
-  void SetCommunicationContext(OpHandleBase *op_handle,
-                               const platform::Place &p) const;
 };
 }  // namespace details
 }  // namespace framework

paddle/fluid/framework/parallel_executor.cc

Lines changed: 6 additions & 39 deletions

@@ -233,30 +233,9 @@ ParallelExecutor::ParallelExecutor(
 
 void ParallelExecutor::BCastParamsToDevices(
     const std::unordered_set<std::string> &vars) const {
-  // the initializing bcast, all vars would be bcast from device(0),
-  // otherwise
-  // bcast from the specified device.
-  bool initializing = member_->executor_ ? false : true;
+  // the initializing bcast, all vars would be bcast from device(0).
   for (auto &var : vars) {
-    int var_dev_id = -1;
-    if (member_->executor_) {
-      auto &sharded_var_device =
-          member_->executor_->Graph().Get<details::ShardedVarDevice>(
-              details::kShardedVarDevice);
-      if (sharded_var_device.find(var) != sharded_var_device.end()) {
-        var_dev_id = sharded_var_device.at(var);
-      }
-    }
-
-    if (!initializing && var_dev_id == -1) continue;
-
-    framework::Variable *main_var = nullptr;
-    if (initializing) {
-      main_var = member_->local_scopes_[0]->FindVar(var);
-    } else {
-      main_var = member_->local_scopes_[var_dev_id]->FindVar(var);
-    }
-
+    framework::Variable *main_var = member_->local_scopes_[0]->FindVar(var);
     if (main_var == nullptr || !main_var->IsType<LoDTensor>()) {
       continue;
     }

@@ -272,8 +251,7 @@ void ParallelExecutor::BCastParamsToDevices(
       auto place = member_->places_[i];
       void *buffer;
 
-      if ((initializing && i == 0) ||
-          (!initializing && static_cast<int>(i) == var_dev_id)) {
+      if (i == 0) {
         buffer = const_cast<void *>(main_tensor.data<void>());
       } else {
         auto local_scope = member_->local_scopes_[i];

@@ -290,29 +268,18 @@ void ParallelExecutor::BCastParamsToDevices(
         platform::NCCLGroupGuard guard;
         for (size_t i = 0; i < member_->places_.size(); ++i) {
           auto &nccl_ctx = member_->nccl_ctxs_->at(member_->places_[i]);
-          if (initializing) {
-            platform::dynload::ncclBcast(buffers[i], numel, data_type, 0,
-                                         nccl_ctx.comm_, nccl_ctx.stream());
-          } else {
-            if (var_dev_id >= 0) {
-              platform::dynload::ncclBcast(buffers[i], numel, data_type,
-                                           var_dev_id, nccl_ctx.comm_,
-                                           nccl_ctx.stream());
-            }
-          }
+          platform::dynload::ncclBcast(buffers[i], numel, data_type, 0,
+                                       nccl_ctx.comm_, nccl_ctx.stream());
         }
         member_->nccl_ctxs_->WaitAll();
       }
-
 #else
       PADDLE_THROW("Not compiled with CUDA");
 #endif
     } else {
       platform::CPUPlace cpu;
       for (size_t i = 0; i < member_->places_.size(); ++i) {
-        if ((initializing && i == 0) ||
-            (!initializing && static_cast<int>(i) == var_dev_id))
-          continue;
+        if (i == 0) continue;
 
         auto local_scope = member_->local_scopes_[i];
         auto *t = local_scope->Var(var)->GetMutable<LoDTensor>();
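With the `initializing` branch gone, every variable is broadcast from device 0 unconditionally. The following is a minimal CPU-only sketch of that pattern, not Paddle code: plain vectors stand in for the per-device tensors, and std::copy stands in for ncclBcast with root 0.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  const std::size_t numel = 4;
  // buffers[0] plays the role of the parameter tensor on device 0.
  std::vector<std::vector<float>> buffers = {
      {1.f, 2.f, 3.f, 4.f},
      std::vector<float>(numel),
      std::vector<float>(numel)};

  // Every other buffer receives device 0's data, as in
  // ncclBcast(buffers[i], numel, data_type, /*root=*/0, ...).
  for (std::size_t i = 1; i < buffers.size(); ++i) {
    std::copy(buffers[0].begin(), buffers[0].end(), buffers[i].begin());
  }

  for (const auto &b : buffers) std::cout << b[2] << ' ';  // prints: 3 3 3
}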

paddle/fluid/framework/parallel_executor.h

Lines changed: 1 addition & 1 deletion

@@ -72,9 +72,9 @@ class ParallelExecutor {
   void Run(const std::vector<std::string> &fetch_tensors,
            const std::string &fetched_var_name);
 
+ private:
  void BCastParamsToDevices(const std::unordered_set<std::string> &vars) const;
 
- private:
  ParallelExecutorPrivate *member_;
 
 #ifdef PADDLE_WITH_CUDA
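Moving the private: label above BCastParamsToDevices is the "hide" half of the commit title: code outside ParallelExecutor can no longer invoke the broadcast directly and must go through the public interface. A minimal sketch of the effect, with hypothetical names (Executor, Run, Bcast):

#include <iostream>

class Executor {
 public:
  void Run() {
    Bcast();  // the helper is still reachable from inside the class
    std::cout << "run finished\n";
  }

 private:
  // Hidden from callers, like BCastParamsToDevices after this change.
  void Bcast() { std::cout << "bcast from device 0\n"; }
};

int main() {
  Executor e;
  e.Run();
  // e.Bcast();  // would not compile: 'Bcast' is a private member
}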
