@@ -142,7 +142,6 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
142
142
143
143
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build (
144
144
const ProgramDesc &program) const {
145
- VLOG (3 ) << " Building ...." ;
146
145
std::unordered_map<std::string, VarDesc *> all_vars;
147
146
for (auto *var : program.Block (0 ).AllVars ()) {
148
147
all_vars[var->Name ()] = var;
@@ -162,36 +161,32 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
162
161
auto send_vars = FindDistTrainSendVars (program);
163
162
auto recv_vars = FindDistTrainRecvVars (program);
164
163
165
- std::vector<std::unordered_set<std::string>> var_name_on_devices;
166
164
std::vector<std::unordered_set<std::string>> bcast_var_name_set;
167
- var_name_on_devices.resize (places_.size ());
168
165
bcast_var_name_set.resize (places_.size ());
169
166
170
167
size_t cur_device_id = 0 ;
171
168
std::vector<int64_t > balance_grads (places_.size (), 0 );
172
169
173
- auto get_appropriate_dev = [&](std::string &g_name) -> size_t {
174
- auto var_desc = all_vars.at (g_name);
175
- PADDLE_ENFORCE_NOT_NULL (var_desc);
176
- auto dim = framework::make_ddim (var_desc->GetShape ());
177
- int64_t numel = framework::product (dim);
178
- PADDLE_ENFORCE_GE (numel, 0 );
170
+ auto get_appropriate_dev = [&](std::vector<std::string> var_names) -> size_t {
171
+ int64_t numel_all = 0 ;
172
+ for (auto var_name : var_names) {
173
+ auto var_desc = all_vars.at (var_name);
174
+ PADDLE_ENFORCE_NOT_NULL (var_desc);
175
+ auto dim = framework::make_ddim (var_desc->GetShape ());
176
+ int64_t numel = framework::product (dim);
177
+ PADDLE_ENFORCE_GT (numel, 0 );
178
+ numel_all += numel;
179
+ }
180
+
179
181
auto smallest =
180
182
std::min_element (std::begin (balance_grads), std::end (balance_grads));
181
183
size_t dev_id =
182
184
static_cast <size_t >(std::distance (std::begin (balance_grads), smallest));
183
- balance_grads[dev_id] += numel ;
185
+ balance_grads[dev_id] += numel_all ;
184
186
return dev_id;
185
187
};
186
188
187
189
bool is_forwarding = true ;
188
- int rpc_op_device_id = 0 ;
189
- auto schedule_rpc_op = [&]() -> void {
190
- rpc_op_device_id++;
191
- if (rpc_op_device_id >= static_cast <int >(places_.size ())) {
192
- rpc_op_device_id = 0 ;
193
- }
194
- };
195
190
196
191
for (auto *op : program.Block (0 ).AllOps ()) {
197
192
if (boost::get<int >(
@@ -200,37 +195,40 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
200
195
// append rpc op if program is distributed trainer main program.
201
196
// always use the first device
202
197
if (op->Type () == " send_vars" ) {
203
- auto got = remote_vars_devices_.find (op->InputArgumentNames ()[0 ]);
204
- if (got == remote_vars_devices_.end ()) {
205
- schedule_rpc_op ();
206
- } else {
207
- rpc_op_device_id = got->second ;
198
+ int op_dev_id = GetVarDeviceID (op->InputArgumentNames ()[0 ]);
199
+ if (op_dev_id == -1 ) {
200
+ op_dev_id = get_appropriate_dev (op->InputArgumentNames ());
201
+ for (auto &varname : op->InputArgumentNames ()) {
202
+ var_name_on_devices_.emplace (varname, op_dev_id);
203
+ }
208
204
}
209
- CreateRPCOp (&result, *op, rpc_op_device_id );
205
+ CreateRPCOp (&result, *op, op_dev_id );
210
206
} else if (op->Type () == " recv" ) {
211
- schedule_rpc_op ( );
207
+ int op_dev_id = get_appropriate_dev (op-> OutputArgumentNames () );
212
208
for (auto &varname : op->OutputArgumentNames ()) {
213
- remote_vars_devices_. insert ({ varname, rpc_op_device_id} );
209
+ var_name_on_devices_. emplace ( varname, op_dev_id );
214
210
}
215
- CreateRPCOp (&result, *op, rpc_op_device_id );
211
+ CreateRPCOp (&result, *op, op_dev_id );
216
212
} else {
213
+ // send_barrier and fetch_barrier op would run on device 0
217
214
CreateRPCOp (&result, *op, 0 );
218
215
}
219
216
} else if (IsDistTrainOp (*op, send_vars, recv_vars)) {
220
217
if (op->Type () == " split_byref" ) {
221
- schedule_rpc_op ( );
218
+ int op_dev_id = get_appropriate_dev (op-> OutputArgumentNames () );
222
219
for (auto &varname : op->OutputArgumentNames ()) {
223
- remote_vars_devices_. insert ({ varname, rpc_op_device_id} );
220
+ var_name_on_devices_. emplace ( varname, op_dev_id );
224
221
}
225
- CreateDistTrainOp (&result, *op, rpc_op_device_id);
226
- }
227
- if (op->Type () == " concat" ) {
228
- auto got = remote_vars_devices_.find (op->InputArgumentNames ()[0 ]);
229
- PADDLE_ENFORCE (got != remote_vars_devices_.end (),
222
+ CreateDistTrainOp (&result, *op, op_dev_id);
223
+ } else if (op->Type () == " concat" ) {
224
+ int op_dev_id = GetVarDeviceID (op->InputArgumentNames ()[0 ]);
225
+ PADDLE_ENFORCE (op_dev_id != -1 ,
230
226
" can not find right place to concatenate received var." );
231
- CreateDistTrainOp (&result, *op, got-> second );
227
+ CreateDistTrainOp (&result, *op, op_dev_id );
232
228
} else {
233
- CreateDistTrainOp (&result, *op, 0 );
229
+ PADDLE_ENFORCE (
230
+ " the distribute training related op should be in [split_byref, "
231
+ " concat]." );
234
232
}
235
233
} else if (IsScaleLossOp (*op)) {
236
234
// user can customize loss@grad if not use_default_grad_scale_
@@ -240,13 +238,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
240
238
}
241
239
is_forwarding = false ;
242
240
} else {
243
- int op_dev_id = GetOpDeviceID (var_name_on_devices, *op);
241
+ int op_dev_id = GetOpDeviceID (*op);
244
242
if (op_dev_id == -1 ) { // var on all device
245
243
CreateComputationalOps (&result, *op, places_.size ());
246
244
} else {
247
245
CreateComputationalOp (&result, *op, op_dev_id);
248
246
for (auto &var_name : op->OutputArgumentNames ()) {
249
- var_name_on_devices[op_dev_id] .emplace (var_name);
247
+ var_name_on_devices_ .emplace (var_name, op_dev_id );
250
248
}
251
249
}
252
250
if (!is_forwarding && places_.size () > 1 ) {
@@ -269,9 +267,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
269
267
270
268
switch (strategy_.reduce_ ) {
271
269
case BuildStrategy::ReduceStrategy::kReduce :
272
- cur_device_id = get_appropriate_dev (g_name);
270
+ cur_device_id = get_appropriate_dev ({ g_name} );
273
271
CreateReduceOp (&result, g_name, cur_device_id);
274
- var_name_on_devices[cur_device_id] .emplace (g_name);
272
+ var_name_on_devices_ .emplace (g_name, cur_device_id );
275
273
bcast_var_name_set[cur_device_id].emplace (p_name);
276
274
break ;
277
275
case BuildStrategy::ReduceStrategy::kAllReduce :
@@ -402,24 +400,23 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
402
400
return is_pg_once;
403
401
}
404
402
405
- int MultiDevSSAGraphBuilder::GetOpDeviceID (
406
- const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
407
- const OpDesc &op) const {
403
+ int MultiDevSSAGraphBuilder::GetOpDeviceID (const OpDesc &op) const {
408
404
if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce ) {
409
405
return -1 ;
410
406
}
411
407
412
- int var_dev_id = -1 ;
413
- for (auto &var_name : op.InputArgumentNames ()) {
414
- if (var_dev_id != -1 ) break ;
415
- for (size_t i = 0 ; i < var_name_on_devices.size (); ++i) {
416
- if (var_name_on_devices[i].count (var_name)) {
417
- var_dev_id = static_cast <int >(i);
418
- break ;
419
- }
408
+ for (auto &varname : op.InputArgumentNames ()) {
409
+ int dev_id = GetVarDeviceID (varname);
410
+ if (dev_id != -1 ) {
411
+ return dev_id;
420
412
}
421
413
}
422
- return var_dev_id;
414
+ return -1 ;
415
+ }
416
+
417
+ int MultiDevSSAGraphBuilder::GetVarDeviceID (const std::string &varname) const {
418
+ auto got = var_name_on_devices_.find (varname);
419
+ return got == var_name_on_devices_.end () ? -1 : got->second ;
423
420
}
424
421
425
422
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp (SSAGraph *result) const {
0 commit comments