@@ -209,30 +209,9 @@ ParallelExecutor::ParallelExecutor(
 
 void ParallelExecutor::BCastParamsToDevices(
     const std::unordered_set<std::string> &vars) const {
-  // the initializing bcast, all vars would be bcast from device(0),
-  // otherwise
-  // bcast from the specified device.
-  bool initializing = member_->executor_ ? false : true;
+  // the initializing bcast, all vars would be bcast from device(0).
   for (auto &var : vars) {
-    int var_dev_id = -1;
-    if (member_->executor_) {
-      auto &sharded_var_device =
-          member_->executor_->Graph().Get<details::ShardedVarDevice>(
-              details::kShardedVarDevice);
-      if (sharded_var_device.find(var) != sharded_var_device.end()) {
-        var_dev_id = sharded_var_device.at(var);
-      }
-    }
-
-    if (!initializing && var_dev_id == -1) continue;
-
-    framework::Variable *main_var = nullptr;
-    if (initializing) {
-      main_var = member_->local_scopes_[0]->FindVar(var);
-    } else {
-      main_var = member_->local_scopes_[var_dev_id]->FindVar(var);
-    }
-
+    framework::Variable *main_var = member_->local_scopes_[0]->FindVar(var);
     if (main_var == nullptr || !main_var->IsType<LoDTensor>()) {
       continue;
     }
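
In effect, this hunk collapses the old two-mode source selection into a constant. A distilled sketch for review (illustrative only; these helper functions do not exist in the codebase, and `initializing`/`var_dev_id` are the locals deleted above):

```cpp
// Illustrative distillation of the removed logic, not code from this commit.
int SourceScopeBefore(bool initializing, int var_dev_id) {
  // The initial broadcast read from device 0; later re-broadcasts read from
  // the device recorded for the variable in the graph's ShardedVarDevice
  // map, and vars without an entry (var_dev_id == -1) were skipped.
  return initializing ? 0 : var_dev_id;
}

int SourceScopeAfter() {
  // Every broadcast now reads from the scope of device 0, unconditionally.
  return 0;
}
```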
@@ -248,8 +227,7 @@ void ParallelExecutor::BCastParamsToDevices(
         auto place = member_->places_[i];
         void *buffer;
 
-        if ((initializing && i == 0) ||
-            (!initializing && static_cast<int>(i) == var_dev_id)) {
+        if (i == 0) {
           buffer = const_cast<void *>(main_tensor.data<void>());
         } else {
           auto local_scope = member_->local_scopes_[i];
@@ -266,29 +244,18 @@ void ParallelExecutor::BCastParamsToDevices(
         platform::NCCLGroupGuard guard;
         for (size_t i = 0; i < member_->places_.size(); ++i) {
           auto &nccl_ctx = member_->nccl_ctxs_->at(member_->places_[i]);
-          if (initializing) {
-            platform::dynload::ncclBcast(buffers[i], numel, data_type, 0,
-                                         nccl_ctx.comm_, nccl_ctx.stream());
-          } else {
-            if (var_dev_id >= 0) {
-              platform::dynload::ncclBcast(buffers[i], numel, data_type,
-                                           var_dev_id, nccl_ctx.comm_,
-                                           nccl_ctx.stream());
-            }
-          }
+          platform::dynload::ncclBcast(buffers[i], numel, data_type, 0,
+                                       nccl_ctx.comm_, nccl_ctx.stream());
         }
         member_->nccl_ctxs_->WaitAll();
       }
-
 #else
       PADDLE_THROW("Not compiled with CUDA");
 #endif
     } else {
       platform::CPUPlace cpu;
       for (size_t i = 0; i < member_->places_.size(); ++i) {
-        if ((initializing && i == 0) ||
-            (!initializing && static_cast<int>(i) == var_dev_id))
-          continue;
+        if (i == 0) continue;
 
         auto local_scope = member_->local_scopes_[i];
         auto *t = local_scope->Var(var)->GetMutable<LoDTensor>();
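
Since the loop now always broadcasts with root 0, a minimal self-contained sketch of the single-process, multi-GPU pattern it relies on may be useful. This assumes `platform::NCCLGroupGuard` wraps `ncclGroupStart()`/`ncclGroupEnd()`; the helper name and parameters below are hypothetical, not PaddlePaddle APIs:

```cpp
#include <cuda_runtime.h>
#include <nccl.h>
#include <vector>

// Hypothetical helper: one process drives N GPUs, with one communicator
// and one stream per device, mirroring member_->nccl_ctxs_ above.
void BcastFromDevice0(const std::vector<ncclComm_t> &comms,
                      const std::vector<cudaStream_t> &streams,
                      const std::vector<void *> &buffers, size_t numel,
                      ncclDataType_t dtype) {
  // Group the per-device calls so NCCL matches them into one collective
  // (presumably what platform::NCCLGroupGuard does in its ctor/dtor).
  ncclGroupStart();
  for (size_t i = 0; i < comms.size(); ++i) {
    // Every call names the same root (0): buffers[0] is the source read on
    // device 0, and each other buffers[i] is written on its own device.
    ncclBcast(buffers[i], numel, dtype, /*root=*/0, comms[i], streams[i]);
  }
  ncclGroupEnd();
  // The collective is asynchronous; synchronize before touching the data,
  // analogous to member_->nccl_ctxs_->WaitAll() in the code above.
  for (auto s : streams) cudaStreamSynchronize(s);
}
```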