@@ -244,6 +244,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
244
244
result.Set (" vars" , new GraphVars (places_.size ()));
245
245
result.Set (" dep_vars" , new GraphDepVars);
246
246
result.Set (" ops" , new GraphOps);
247
+ result.Set (" sharded_var_device" , new ShardedVarDevice);
247
248
248
249
// find send/recv vars so that we can place the distributed training
249
250
// related op in the place 0
@@ -276,11 +277,12 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
276
277
// the block.
277
278
is_forwarding = false ;
278
279
} else {
279
- int op_dev_id = GetOpDeviceID (node);
280
+ int op_dev_id = GetOpDeviceID (result, node);
280
281
if (op_dev_id != -1 ) { // This op only runs on one specific device.
281
282
CreateComputationalOp (&result, node, op_dev_id);
282
283
for (ir::Node *n : node->outputs ) {
283
- var_name_on_devices_.emplace (n->Name (), op_dev_id);
284
+ graph->Get <ShardedVarDevice>(" sharded_var_device" )
285
+ .emplace (n->Name (), op_dev_id);
284
286
}
285
287
} else {
286
288
// This op runs on all devices, and its output may have parameter's
@@ -317,7 +319,8 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
317
319
case BuildStrategy::ReduceStrategy::kReduce :
318
320
cur_device_id = GetAppropriateDeviceID ({g_name});
319
321
CreateReduceOp (&result, g_name, cur_device_id);
320
- var_name_on_devices_.emplace (g_name, cur_device_id);
322
+ graph->Get <ShardedVarDevice>(" sharded_var_device" )
323
+ .emplace (g_name, cur_device_id);
321
324
bcast_var_name_set[cur_device_id].emplace (p_name);
322
325
break ;
323
326
case BuildStrategy::ReduceStrategy::kAllReduce :
@@ -499,7 +502,8 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
499
502
return is_pg_once;
500
503
}
501
504
502
- int MultiDevSSAGraphBuilder::GetOpDeviceID (ir::Node *node) const {
505
+ int MultiDevSSAGraphBuilder::GetOpDeviceID (const ir::Graph &graph,
506
+ ir::Node *node) const {
503
507
if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce ) {
504
508
return -1 ;
505
509
}
@@ -512,15 +516,17 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const {
512
516
node->Op ()->GetAttr (OpProtoAndCheckerMaker::OpRoleVarAttrName ()));
513
517
514
518
PADDLE_ENFORCE_EQ (param_grad.size (), 2U );
515
- int dev_id = GetVarDeviceID (param_grad[1 ]);
519
+ int dev_id = GetVarDeviceID (graph, param_grad[1 ]);
516
520
PADDLE_ENFORCE_NE (dev_id, -1 , " dev_id should not be -1.[%s, %s, %s]" ,
517
521
node->Op ()->Type (), param_grad[0 ], param_grad[1 ]);
518
522
return dev_id;
519
523
}
520
524
521
- int MultiDevSSAGraphBuilder::GetVarDeviceID (const std::string &varname) const {
522
- auto got = var_name_on_devices_.find (varname);
523
- return got == var_name_on_devices_.end () ? -1 : got->second ;
525
// Looks up `varname` in the ShardedVarDevice attribute stored on `graph`
// and returns the device id recorded for it, or -1 when no entry exists.
+ int MultiDevSSAGraphBuilder::GetVarDeviceID (const ir::Graph &graph,
526
+ const std::string &varname) const {
527
+ auto &sharded_var_device = graph.Get <ShardedVarDevice>(" sharded_var_device" );
528
+ auto got = sharded_var_device.find (varname);
529
// -1 is the "not found" result; otherwise the mapped device id is returned.
+ return got == sharded_var_device.end () ? -1 : got->second ;
524
530
}
525
531
526
532
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp (ir::Graph *result) const {
@@ -625,20 +631,23 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
625
631
if (node->Op ()->Type () == " split_byref" ||
626
632
node->Op ()->Type () == " split_selected_rows" ) {
627
633
// TODO(paddle-dev): getting the first var is not safe.
628
- op_dev_id = GetVarDeviceID (input_var_names[0 ]);
634
+ op_dev_id = GetVarDeviceID (*result, input_var_names[0 ]);
629
635
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce ) {
630
636
op_dev_id = GetAppropriateDeviceID (input_var_names);
631
637
for (auto &varname : input_var_names) {
632
- var_name_on_devices_.emplace (varname, op_dev_id);
638
+ result->Get <ShardedVarDevice>(" sharded_var_device" )
639
+ .emplace (varname, op_dev_id);
633
640
}
634
641
}
635
642
for (auto &varname : output_var_names) {
636
- var_name_on_devices_.emplace (varname, op_dev_id);
643
+ result->Get <ShardedVarDevice>(" sharded_var_device" )
644
+ .emplace (varname, op_dev_id);
637
645
}
638
646
} else if (node->Op ()->Type () == " concat" ) {
639
- op_dev_id = GetVarDeviceID (input_var_names[0 ]);
647
+ op_dev_id = GetVarDeviceID (*result, input_var_names[0 ]);
640
648
for (auto &varname : output_var_names) {
641
- var_name_on_devices_.emplace (varname, op_dev_id);
649
+ result->Get <ShardedVarDevice>(" sharded_var_device" )
650
+ .emplace (varname, op_dev_id);
642
651
}
643
652
} else {
644
653
PADDLE_ENFORCE (
@@ -663,7 +672,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
663
672
int op_dev_id = -1 ;
664
673
if (node->Op ()->Type () == " send" ) {
665
674
// TODO(paddle-dev): getting the first var is not safe.
666
- op_dev_id = GetVarDeviceID (node->inputs [0 ]->Name ());
675
+ op_dev_id = GetVarDeviceID (*result, node->inputs [0 ]->Name ());
667
676
PADDLE_ENFORCE (!ir::IsControlDepVar (*node->inputs [0 ]),
668
677
" This hack no longer holds, please fix." );
669
678
// the variable name which contains .block means it was split by
@@ -678,7 +687,8 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
678
687
}
679
688
op_dev_id = GetAppropriateDeviceID (input_var_names);
680
689
for (auto &varname : input_var_names) {
681
- var_name_on_devices_.emplace (varname, op_dev_id);
690
+ result->Get <ShardedVarDevice>(" sharded_var_device" )
691
+ .emplace (varname, op_dev_id);
682
692
}
683
693
}
684
694
} else if (node->Op ()->Type () == " recv" ) {
@@ -688,7 +698,8 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
688
698
}
689
699
op_dev_id = GetAppropriateDeviceID (output_var_names);
690
700
for (auto &varname : output_var_names) {
691
- var_name_on_devices_.emplace (varname, op_dev_id);
701
+ result->Get <ShardedVarDevice>(" sharded_var_device" )
702
+ .emplace (varname, op_dev_id);
692
703
}
693
704
} else {
694
705
// send_barrier and fetch_barrier op can be scheduled on device 0
@@ -730,3 +741,6 @@ bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const {
730
741
} // namespace details
731
742
} // namespace framework
732
743
} // namespace paddle
744
+
745
+ REGISTER_PASS (multi_device_pass,
746
+ paddle::framework::details::MultiDevSSAGraphBuilder);
0 commit comments