@@ -57,6 +57,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
57
57
for (auto &p : params) {
58
58
grad_names_.insert (GradVarName (p));
59
59
}
60
+ balance_vars_.resize (places_.size (), 0 );
60
61
}
61
62
62
63
void MultiDevSSAGraphBuilder::CreateOpHandleIOs (SSAGraph *result,
@@ -140,11 +141,30 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
140
141
checker (op.InputArgumentNames (), recv_vars);
141
142
}
142
143
144
+ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID (
145
+ const std::vector<std::string> &var_names) const {
146
+ int64_t numel_sum = 0 ;
147
+ for (auto var_name : var_names) {
148
+ auto var_desc = all_vars_.at (var_name);
149
+ PADDLE_ENFORCE_NOT_NULL (var_desc);
150
+ auto dim = framework::make_ddim (var_desc->GetShape ());
151
+ int64_t numel = framework::product (dim);
152
+ PADDLE_ENFORCE_GT (numel, 0 );
153
+ numel_sum += numel;
154
+ }
155
+
156
+ auto smallest =
157
+ std::min_element (std::begin (balance_vars_), std::end (balance_vars_));
158
+ size_t dev_id =
159
+ static_cast <size_t >(std::distance (std::begin (balance_vars_), smallest));
160
+ balance_vars_[dev_id] += numel_sum;
161
+ return dev_id;
162
+ }
163
+
143
164
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build (
144
165
const ProgramDesc &program) const {
145
- std::unordered_map<std::string, VarDesc *> all_vars;
146
166
for (auto *var : program.Block (0 ).AllVars ()) {
147
- all_vars[ var->Name ()] = var;
167
+ all_vars_. emplace ( var->Name (), var) ;
148
168
}
149
169
150
170
auto graph = new SSAGraph ();
@@ -161,35 +181,16 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
161
181
auto send_vars = FindDistTrainSendVars (program);
162
182
auto recv_vars = FindDistTrainRecvVars (program);
163
183
164
- std::vector<std::unordered_set<std::string>> var_name_on_devices;
165
184
std::vector<std::unordered_set<std::string>> bcast_var_name_set;
166
- var_name_on_devices.resize (places_.size ());
167
185
bcast_var_name_set.resize (places_.size ());
168
186
169
187
size_t cur_device_id = 0 ;
170
- std::vector<int64_t > balance_grads (places_.size (), 0 );
171
-
172
- auto get_appropriate_dev = [&](std::string &g_name) -> size_t {
173
- auto var_desc = all_vars.at (g_name);
174
- PADDLE_ENFORCE_NOT_NULL (var_desc);
175
- auto dim = framework::make_ddim (var_desc->GetShape ());
176
- int64_t numel = framework::product (dim);
177
- PADDLE_ENFORCE_GE (numel, 0 );
178
- auto smallest =
179
- std::min_element (std::begin (balance_grads), std::end (balance_grads));
180
- size_t dev_id =
181
- static_cast <size_t >(std::distance (std::begin (balance_grads), smallest));
182
- balance_grads[dev_id] += numel;
183
- return dev_id;
184
- };
185
-
186
188
bool is_forwarding = true ;
189
+
187
190
for (auto *op : program.Block (0 ).AllOps ()) {
188
191
if (boost::get<int >(
189
192
op->GetAttr (OpProtoAndCheckerMaker::OpRoleAttrName ())) ==
190
193
static_cast <int >(OpRole::kRPC )) {
191
- // append rpc op if program is distributed trainer main program.
192
- // always use the first device
193
194
CreateRPCOp (&result, *op);
194
195
} else if (IsDistTrainOp (*op, send_vars, recv_vars)) {
195
196
CreateDistTrainOp (&result, *op);
@@ -201,13 +202,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
201
202
}
202
203
is_forwarding = false ;
203
204
} else {
204
- int op_dev_id = GetOpDeviceID (var_name_on_devices, *op);
205
+ int op_dev_id = GetOpDeviceID (*op);
205
206
if (op_dev_id == -1 ) { // var on all device
206
207
CreateComputationalOps (&result, *op, places_.size ());
207
208
} else {
208
209
CreateComputationalOp (&result, *op, op_dev_id);
209
210
for (auto &var_name : op->OutputArgumentNames ()) {
210
- var_name_on_devices[op_dev_id] .emplace (var_name);
211
+ var_name_on_devices_ .emplace (var_name, op_dev_id );
211
212
}
212
213
}
213
214
if (!is_forwarding && places_.size () > 1 ) {
@@ -230,13 +231,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
230
231
231
232
switch (strategy_.reduce_ ) {
232
233
case BuildStrategy::ReduceStrategy::kReduce :
233
- cur_device_id = get_appropriate_dev ( g_name);
234
+ cur_device_id = GetAppropriateDeviceID ({ g_name} );
234
235
CreateReduceOp (&result, g_name, cur_device_id);
235
- var_name_on_devices[cur_device_id] .emplace (g_name);
236
+ var_name_on_devices_ .emplace (g_name, cur_device_id );
236
237
bcast_var_name_set[cur_device_id].emplace (p_name);
237
238
break ;
238
239
case BuildStrategy::ReduceStrategy::kAllReduce :
239
- if (IsSparseGradient (all_vars, g_name)) {
240
+ if (IsSparseGradient (g_name)) {
240
241
CreateReduceOp (&result, g_name, 0 );
241
242
CreateBroadcastOp (&result, g_name, 0 );
242
243
} else {
@@ -273,11 +274,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
273
274
return std::unique_ptr<SSAGraph>(graph);
274
275
}
275
276
276
- bool MultiDevSSAGraphBuilder::IsSparseGradient (
277
- const std::unordered_map<std::string, VarDesc *> &all_vars,
278
- const std::string &og) const {
279
- PADDLE_ENFORCE (all_vars.count (og) != 0 );
280
- if (all_vars.at (og)->GetType () == proto::VarType::SELECTED_ROWS) {
277
+ bool MultiDevSSAGraphBuilder::IsSparseGradient (const std::string &og) const {
278
+ PADDLE_ENFORCE (all_vars_.count (og) != 0 );
279
+ if (all_vars_.at (og)->GetType () == proto::VarType::SELECTED_ROWS) {
281
280
return true ;
282
281
}
283
282
return false ;
@@ -363,24 +362,23 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
363
362
return is_pg_once;
364
363
}
365
364
366
- int MultiDevSSAGraphBuilder::GetOpDeviceID (
367
- const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
368
- const OpDesc &op) const {
365
+ int MultiDevSSAGraphBuilder::GetOpDeviceID (const OpDesc &op) const {
369
366
if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce ) {
370
367
return -1 ;
371
368
}
372
369
373
- int var_dev_id = -1 ;
374
- for (auto &var_name : op.InputArgumentNames ()) {
375
- if (var_dev_id != -1 ) break ;
376
- for (size_t i = 0 ; i < var_name_on_devices.size (); ++i) {
377
- if (var_name_on_devices[i].count (var_name)) {
378
- var_dev_id = static_cast <int >(i);
379
- break ;
380
- }
370
+ for (auto &varname : op.InputArgumentNames ()) {
371
+ int dev_id = GetVarDeviceID (varname);
372
+ if (dev_id != -1 ) {
373
+ return dev_id;
381
374
}
382
375
}
383
- return var_dev_id;
376
+ return -1 ;
377
+ }
378
+
379
+ int MultiDevSSAGraphBuilder::GetVarDeviceID (const std::string &varname) const {
380
+ auto got = var_name_on_devices_.find (varname);
381
+ return got == var_name_on_devices_.end () ? -1 : got->second ;
384
382
}
385
383
386
384
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp (SSAGraph *result) const {
@@ -463,16 +461,65 @@ void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
463
461
464
462
void MultiDevSSAGraphBuilder::CreateDistTrainOp (SSAGraph *result,
465
463
const OpDesc &op) const {
466
- CreateComputationalOp (result, op, 0 );
464
+ int op_dev_id = -1 ;
465
+ if (op.Type () == " split_byref" ) {
466
+ op_dev_id = GetVarDeviceID (op.InputArgumentNames ()[0 ]);
467
+ if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce ) {
468
+ op_dev_id = GetAppropriateDeviceID (op.InputArgumentNames ());
469
+ for (auto &varname : op.InputArgumentNames ()) {
470
+ var_name_on_devices_.emplace (varname, op_dev_id);
471
+ }
472
+ }
473
+ for (auto &varname : op.OutputArgumentNames ()) {
474
+ var_name_on_devices_.emplace (varname, op_dev_id);
475
+ }
476
+ } else if (op.Type () == " concat" ) {
477
+ op_dev_id = GetVarDeviceID (op.InputArgumentNames ()[0 ]);
478
+ } else {
479
+ PADDLE_ENFORCE (
480
+ " the distribute training related op should be in [split_byref, "
481
+ " concat]." );
482
+ }
483
+
484
+ PADDLE_ENFORCE (op_dev_id != -1 ,
485
+ " can not find right place for distributed op: %s" , op.Type ());
486
+
487
+ CreateComputationalOp (result, op, op_dev_id);
467
488
if (op.Type () == " concat" ) {
468
489
ConnectOp (result, result->ops_ .back ().get (), " fetch_barrier" );
469
490
}
470
491
}
471
492
472
493
void MultiDevSSAGraphBuilder::CreateRPCOp (SSAGraph *result,
473
494
const OpDesc &op) const {
474
- result->ops_ .emplace_back (
475
- new RPCOpHandle (op, local_scopes_[0 ], op.Type (), places_[0 ]));
495
+ int op_dev_id = -1 ;
496
+ if (op.Type () == " send" ) {
497
+ op_dev_id = GetVarDeviceID (op.InputArgumentNames ()[0 ]);
498
+ // the variable name which contains .block means it was split by
499
+ // split_byref op
500
+ // so that we can balance the variable blocks to all the pserver instances.
501
+ if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce &&
502
+ op.InputArgumentNames ()[0 ].find (" .block" ) == std::string::npos) {
503
+ op_dev_id = GetAppropriateDeviceID (op.InputArgumentNames ());
504
+ for (auto &varname : op.InputArgumentNames ()) {
505
+ var_name_on_devices_.emplace (varname, op_dev_id);
506
+ }
507
+ }
508
+ } else if (op.Type () == " recv" ) {
509
+ op_dev_id = GetAppropriateDeviceID (op.OutputArgumentNames ());
510
+ for (auto &varname : op.OutputArgumentNames ()) {
511
+ var_name_on_devices_.emplace (varname, op_dev_id);
512
+ }
513
+ } else {
514
+ // send_barrier and fetch_barrier op can be scheduled on device 0
515
+ op_dev_id = 0 ;
516
+ }
517
+
518
+ PADDLE_ENFORCE (op_dev_id != -1 , " can not find the right place for rpc op: %s" ,
519
+ op.Type ());
520
+
521
+ result->ops_ .emplace_back (new RPCOpHandle (op, local_scopes_[op_dev_id],
522
+ op.Type (), places_[op_dev_id]));
476
523
477
524
if (op.Type () == " send_barrier" ) {
478
525
ConnectOp (result, result->ops_ .back ().get (), " send" );
@@ -488,9 +535,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
488
535
" send, send_barrier. recv, fetch_barrier]" );
489
536
}
490
537
491
- // TODO(Yancey1989): schedule rpc op on different place may
492
- // increate throughput
493
- CreateOpHandleIOs (result, op, 0 );
538
+ CreateOpHandleIOs (result, op, op_dev_id);
494
539
}
495
540
496
541
bool MultiDevSSAGraphBuilder::IsScaleLossOp (const OpDesc &op) const {
0 commit comments