@@ -57,6 +57,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
  for (auto &p : params) {
    grad_names_.insert(GradVarName(p));
  }
+  balance_vars_.resize(places_.size(), 0);
}

void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
@@ -140,11 +141,30 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
         checker(op.InputArgumentNames(), recv_vars);
}

+size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
+    const std::vector<std::string> &var_names) const {
+  int64_t numel_sum = 0;
+  for (auto var_name : var_names) {
+    auto var_desc = all_vars_.at(var_name);
+    PADDLE_ENFORCE_NOT_NULL(var_desc);
+    auto dim = framework::make_ddim(var_desc->GetShape());
+    int64_t numel = framework::product(dim);
+    PADDLE_ENFORCE_GT(numel, 0);
+    numel_sum += numel;
+  }
+
+  auto smallest =
+      std::min_element(std::begin(balance_vars_), std::end(balance_vars_));
+  size_t dev_id =
+      static_cast<size_t>(std::distance(std::begin(balance_vars_), smallest));
+  balance_vars_[dev_id] += numel_sum;
+  return dev_id;
+}
+
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
    const ProgramDesc &program) const {
-  std::unordered_map<std::string, VarDesc *> all_vars;
  for (auto *var : program.Block(0).AllVars()) {
-    all_vars[var->Name()] = var;
+    all_vars_.emplace(var->Name(), var);
  }

  auto graph = new SSAGraph();
@@ -165,71 +185,15 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
  bcast_var_name_set.resize(places_.size());

  size_t cur_device_id = 0;
-  std::vector<int64_t> balance_grads(places_.size(), 0);
-
-  auto get_appropriate_dev = [&](std::vector<std::string> var_names) -> size_t {
-    int64_t numel_all = 0;
-    for (auto var_name : var_names) {
-      auto var_desc = all_vars.at(var_name);
-      PADDLE_ENFORCE_NOT_NULL(var_desc);
-      auto dim = framework::make_ddim(var_desc->GetShape());
-      int64_t numel = framework::product(dim);
-      PADDLE_ENFORCE_GT(numel, 0);
-      numel_all += numel;
-    }
-
-    auto smallest =
-        std::min_element(std::begin(balance_grads), std::end(balance_grads));
-    size_t dev_id =
-        static_cast<size_t>(std::distance(std::begin(balance_grads), smallest));
-    balance_grads[dev_id] += numel_all;
-    return dev_id;
-  };
-
  bool is_forwarding = true;

  for (auto *op : program.Block(0).AllOps()) {
    if (boost::get<int>(
            op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
        static_cast<int>(OpRole::kRPC)) {
-      // append rpc op if program is distributed trainer main program.
-      // always use the first device
-      if (op->Type() == "send_vars") {
-        int op_dev_id = GetVarDeviceID(op->InputArgumentNames()[0]);
-        if (op_dev_id == -1) {
-          op_dev_id = get_appropriate_dev(op->InputArgumentNames());
-          for (auto &varname : op->InputArgumentNames()) {
-            var_name_on_devices_.emplace(varname, op_dev_id);
-          }
-        }
-        CreateRPCOp(&result, *op, op_dev_id);
-      } else if (op->Type() == "recv") {
-        int op_dev_id = get_appropriate_dev(op->OutputArgumentNames());
-        for (auto &varname : op->OutputArgumentNames()) {
-          var_name_on_devices_.emplace(varname, op_dev_id);
-        }
-        CreateRPCOp(&result, *op, op_dev_id);
-      } else {
-        // send_barrier and fetch_barrier op would run on device 0
-        CreateRPCOp(&result, *op, 0);
-      }
+      CreateRPCOp(&result, *op);
    } else if (IsDistTrainOp(*op, send_vars, recv_vars)) {
-      if (op->Type() == "split_byref") {
-        int op_dev_id = get_appropriate_dev(op->OutputArgumentNames());
-        for (auto &varname : op->OutputArgumentNames()) {
-          var_name_on_devices_.emplace(varname, op_dev_id);
-        }
-        CreateDistTrainOp(&result, *op, op_dev_id);
-      } else if (op->Type() == "concat") {
-        int op_dev_id = GetVarDeviceID(op->InputArgumentNames()[0]);
-        PADDLE_ENFORCE(op_dev_id != -1,
-                       "can not find right place to concatenate received var.");
-        CreateDistTrainOp(&result, *op, op_dev_id);
-      } else {
-        PADDLE_ENFORCE(
-            "the distribute training related op should be in [split_byref, "
-            "concat].");
-      }
+      CreateDistTrainOp(&result, *op);
    } else if (IsScaleLossOp(*op)) {
      // user can customize loss@grad if not use_default_grad_scale_
      if (strategy_.gradient_scale_ !=
@@ -267,13 +231,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(

            switch (strategy_.reduce_) {
              case BuildStrategy::ReduceStrategy::kReduce:
-                cur_device_id = get_appropriate_dev({g_name});
+                cur_device_id = GetAppropriateDeviceID({g_name});
                CreateReduceOp(&result, g_name, cur_device_id);
                var_name_on_devices_.emplace(g_name, cur_device_id);
                bcast_var_name_set[cur_device_id].emplace(p_name);
                break;
              case BuildStrategy::ReduceStrategy::kAllReduce:
-                if (IsSparseGradient(all_vars, g_name)) {
+                if (IsSparseGradient(g_name)) {
                  CreateReduceOp(&result, g_name, 0);
                  CreateBroadcastOp(&result, g_name, 0);
                } else {
@@ -310,11 +274,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
  return std::unique_ptr<SSAGraph>(graph);
}

-bool MultiDevSSAGraphBuilder::IsSparseGradient(
-    const std::unordered_map<std::string, VarDesc *> &all_vars,
-    const std::string &og) const {
-  PADDLE_ENFORCE(all_vars.count(og) != 0);
-  if (all_vars.at(og)->GetType() == proto::VarType::SELECTED_ROWS) {
+bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const {
+  PADDLE_ENFORCE(all_vars_.count(og) != 0);
+  if (all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS) {
    return true;
  }
  return false;
@@ -498,18 +460,66 @@ void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
}

void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
-                                                const OpDesc &op,
-                                                int place_id) const {
-  CreateComputationalOp(result, op, place_id);
+                                                const OpDesc &op) const {
+  int op_dev_id = -1;
+  if (op.Type() == "split_byref") {
+    op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
+    if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
+      op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames());
+      for (auto &varname : op.InputArgumentNames()) {
+        var_name_on_devices_.emplace(varname, op_dev_id);
+      }
+    }
+    for (auto &varname : op.OutputArgumentNames()) {
+      var_name_on_devices_.emplace(varname, op_dev_id);
+    }
+  } else if (op.Type() == "concat") {
+    op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
+  } else {
+    PADDLE_ENFORCE(
+        "the distribute training related op should be in [split_byref, "
+        "concat].");
+  }
+
+  PADDLE_ENFORCE(op_dev_id != -1,
+                 "can not find right place for distributed op: %s", op.Type());
+
+  CreateComputationalOp(result, op, op_dev_id);
  if (op.Type() == "concat") {
    ConnectOp(result, result->ops_.back().get(), "fetch_barrier");
  }
}

-void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, const OpDesc &op,
-                                          int device_id) const {
-  result->ops_.emplace_back(new RPCOpHandle(op, local_scopes_[device_id],
-                                            op.Type(), places_[device_id]));
+void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
+                                          const OpDesc &op) const {
+  int op_dev_id = -1;
+  if (op.Type() == "send") {
+    op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
+    // the variable name which contains .block means it was splited by
+    // split_byref op
+    // so that we can balance the variable blocks to all the pserver instances.
+    if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce &&
+        op.InputArgumentNames()[0].find(".block") == std::string::npos) {
+      op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames());
+      for (auto &varname : op.InputArgumentNames()) {
+        var_name_on_devices_.emplace(varname, op_dev_id);
+      }
+    }
+  } else if (op.Type() == "recv") {
+    op_dev_id = GetAppropriateDeviceID(op.OutputArgumentNames());
+    for (auto &varname : op.OutputArgumentNames()) {
+      var_name_on_devices_.emplace(varname, op_dev_id);
+    }
+  } else {
+    // send_barrier and fetch_barrier op can be scheduled on device 0
+    op_dev_id = 0;
+  }
+
+  PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s",
+                 op.Type());
+
+  result->ops_.emplace_back(new RPCOpHandle(op, local_scopes_[op_dev_id],
+                                            op.Type(), places_[op_dev_id]));

  if (op.Type() == "send_barrier") {
    ConnectOp(result, result->ops_.back().get(), "send");
@@ -525,9 +535,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, const OpDesc &op,
                   "send, send_barrier. recv, fetch_barrier]");
  }

-  // TODO(Yancey1989): schedule rpc op on different place may
-  // increate throughput
-  CreateOpHandleIOs(result, op, device_id);
+  CreateOpHandleIOs(result, op, op_dev_id);
}

bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const {
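For reference, the new GetAppropriateDeviceID implements a simple greedy placement: sum the element counts (numel) of the variables to be placed, pick the device whose accumulated load in balance_vars_ is currently smallest, and charge the sum to that device. Below is a minimal standalone sketch of that idea, not part of this commit; the helper name PickDevice and the sample sizes are illustrative only.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Greedy choice: device with the smallest accumulated numel wins,
// then its running total is increased by the new variables' numel.
size_t PickDevice(std::vector<int64_t> *balance, int64_t numel_sum) {
  auto smallest = std::min_element(balance->begin(), balance->end());
  size_t dev_id =
      static_cast<size_t>(std::distance(balance->begin(), smallest));
  (*balance)[dev_id] += numel_sum;
  return dev_id;
}

int main() {
  std::vector<int64_t> balance(4, 0);  // four devices, initially unloaded
  // Element counts of successive variables (e.g. gradients or split blocks).
  for (int64_t numel : {1000, 200, 1000, 50, 300}) {
    std::cout << "numel " << numel << " -> device "
              << PickDevice(&balance, numel) << "\n";
  }
  return 0;
}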