@@ -233,30 +233,9 @@ ParallelExecutor::ParallelExecutor(
 
 void ParallelExecutor::BCastParamsToDevices(
     const std::unordered_set<std::string> &vars) const {
-  // the initializing bcast, all vars would be bcast from device(0),
-  // otherwise
-  // bcast from the specified device.
-  bool initializing = member_->executor_ ? false : true;
+  // the initializing bcast, all vars would be bcast from device(0).
   for (auto &var : vars) {
-    int var_dev_id = -1;
-    if (member_->executor_) {
-      auto &sharded_var_device =
-          member_->executor_->Graph().Get<details::ShardedVarDevice>(
-              details::kShardedVarDevice);
-      if (sharded_var_device.find(var) != sharded_var_device.end()) {
-        var_dev_id = sharded_var_device.at(var);
-      }
-    }
-
-    if (!initializing && var_dev_id == -1) continue;
-
-    framework::Variable *main_var = nullptr;
-    if (initializing) {
-      main_var = member_->local_scopes_[0]->FindVar(var);
-    } else {
-      main_var = member_->local_scopes_[var_dev_id]->FindVar(var);
-    }
-
+    framework::Variable *main_var = member_->local_scopes_[0]->FindVar(var);
     if (main_var == nullptr || !main_var->IsType<LoDTensor>()) {
       continue;
     }
@@ -272,8 +251,7 @@ void ParallelExecutor::BCastParamsToDevices(
         auto place = member_->places_[i];
         void *buffer;
 
-        if ((initializing && i == 0) ||
-            (!initializing && static_cast<int>(i) == var_dev_id)) {
+        if (i == 0) {
           buffer = const_cast<void *>(main_tensor.data<void>());
         } else {
           auto local_scope = member_->local_scopes_[i];
@@ -290,29 +268,18 @@ void ParallelExecutor::BCastParamsToDevices(
         platform::NCCLGroupGuard guard;
         for (size_t i = 0; i < member_->places_.size(); ++i) {
           auto &nccl_ctx = member_->nccl_ctxs_->at(member_->places_[i]);
-          if (initializing) {
-            platform::dynload::ncclBcast(buffers[i], numel, data_type, 0,
-                                         nccl_ctx.comm_, nccl_ctx.stream());
-          } else {
-            if (var_dev_id >= 0) {
-              platform::dynload::ncclBcast(buffers[i], numel, data_type,
-                                           var_dev_id, nccl_ctx.comm_,
-                                           nccl_ctx.stream());
-            }
-          }
+          platform::dynload::ncclBcast(buffers[i], numel, data_type, 0,
+                                       nccl_ctx.comm_, nccl_ctx.stream());
         }
         member_->nccl_ctxs_->WaitAll();
       }
-
 #else
       PADDLE_THROW("Not compiled with CUDA");
 #endif
     } else {
       platform::CPUPlace cpu;
       for (size_t i = 0; i < member_->places_.size(); ++i) {
-        if ((initializing && i == 0) ||
-            (!initializing && static_cast<int>(i) == var_dev_id))
-          continue;
+        if (i == 0) continue;
 
         auto local_scope = member_->local_scopes_[i];
         auto *t = local_scope->Var(var)->GetMutable<LoDTensor>();
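After this change, every parameter is broadcast from device(0) rather than from a per-variable source device tracked by the removed var_dev_id logic. The following is a minimal standalone sketch of that root-0 broadcast pattern, not the PR's code: one ncclBcast per device, all with root 0, issued inside an NCCL group call (the role platform::NCCLGroupGuard plays above). The device count, buffer names, and sizes are illustrative assumptions, and error checking of CUDA/NCCL return codes is omitted.

// Sketch: broadcast a "parameter" buffer from device 0 to all GPUs with NCCL.
// Illustrative only; assumes two visible GPUs and elides error handling.
#include <cuda_runtime.h>
#include <nccl.h>

#include <vector>

int main() {
  const int ndev = 2;
  const size_t numel = 1024;  // elements in the parameter tensor

  std::vector<int> devs = {0, 1};
  std::vector<ncclComm_t> comms(ndev);
  ncclCommInitAll(comms.data(), ndev, devs.data());

  std::vector<float *> buffers(ndev);
  std::vector<cudaStream_t> streams(ndev);
  for (int i = 0; i < ndev; ++i) {
    cudaSetDevice(devs[i]);
    cudaMalloc(&buffers[i], numel * sizeof(float));
    cudaStreamCreate(&streams[i]);
  }

  // Only device 0 holds the initialized parameter; the other devices receive
  // it, mirroring main_var being taken from local_scopes_[0].
  std::vector<float> host(numel, 1.0f);
  cudaSetDevice(devs[0]);
  cudaMemcpy(buffers[0], host.data(), numel * sizeof(float),
             cudaMemcpyHostToDevice);

  // One ncclBcast per device, all with root 0, grouped so the launches are
  // issued together; the simplified code path above relies on NCCLGroupGuard
  // for the same grouping.
  ncclGroupStart();
  for (int i = 0; i < ndev; ++i) {
    ncclBcast(buffers[i], numel, ncclFloat, /*root=*/0, comms[i], streams[i]);
  }
  ncclGroupEnd();

  // Wait for every device, analogous to member_->nccl_ctxs_->WaitAll().
  for (int i = 0; i < ndev; ++i) {
    cudaSetDevice(devs[i]);
    cudaStreamSynchronize(streams[i]);
  }

  for (int i = 0; i < ndev; ++i) {
    cudaSetDevice(devs[i]);
    cudaFree(buffers[i]);
    cudaStreamDestroy(streams[i]);
    ncclCommDestroy(comms[i]);
  }
  return 0;
}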