@@ -34,16 +34,22 @@ namespace paddle {
 namespace framework {
 namespace details {
 
+static const char kLossVarName[] = "loss_var_name";
+static const char kPlaces[] = "places";
+static const char kParams[] = "params";
+static const char kLocalScopes[] = "local_scopes";
+static const char kStrategy[] = "strategy";
+
 void MultiDevSSAGraphBuilder::Init() const {
-  loss_var_name_ = Get<const std::string>("loss_var_name");
-  places_ = Get<const std::vector<platform::Place>>("places");
-  local_scopes_ = Get<const std::vector<Scope *>>("local_scopes");
-  strategy_ = Get<const BuildStrategy>("strategy");
+  loss_var_name_ = Get<const std::string>(kLossVarName);
+  places_ = Get<const std::vector<platform::Place>>(kPlaces);
+  local_scopes_ = Get<const std::vector<Scope *>>(kLocalScopes);
+  strategy_ = Get<const BuildStrategy>(kStrategy);
 
 #ifdef PADDLE_WITH_CUDA
   nccl_ctxs_ = &Get<platform::NCCLContextMap>("nccl_ctxs");
 #endif
-  for (auto &p : Get<const std::unordered_set<std::string>>("params")) {
+  for (auto &p : Get<const std::unordered_set<std::string>>(kParams)) {
     grad_names_.insert(GradVarName(p));
   }
   balance_vars_.resize(places_.size(), 0);
@@ -58,7 +64,7 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
                                                 ir::Node *node,
                                                 size_t place_id) const {
   auto p = places_[place_id];
-  auto *op_handle = result->Get<GraphOps>("ops").back().get();
+  auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
   op_handle->SetDeviceContext(p,
                               platform::DeviceContextPool::Instance().Get(p));
 
@@ -225,7 +231,7 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
   return sorted_ret;
 }
 
-std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
+std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
     std::unique_ptr<ir::Graph> graph) const {
   Init();
   // Give the topology sort order and rebuild the graph structure.
@@ -241,10 +247,10 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
   std::unordered_set<std::string> og_has_been_broadcast;
 
   // We cannot invoke resize. It is a bug of GCC 4.8
-  result.Set("vars", new GraphVars(places_.size()));
-  result.Set("dep_vars", new GraphDepVars);
-  result.Set("ops", new GraphOps);
-  result.Set("sharded_var_device", new ShardedVarDevice);
+  result.Set(kGraphVars, new GraphVars(places_.size()));
+  result.Set(kGraphDepVars, new GraphDepVars);
+  result.Set(kGraphOps, new GraphOps);
+  result.Set(kShardedVarDevice, new ShardedVarDevice);
 
   // find send/recv vars so that we can place the distributed training
   // realted op in the place 0
@@ -281,7 +287,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
     if (op_dev_id != -1) {  // This op only runs on one specific device.
       CreateComputationalOp(&result, node, op_dev_id);
       for (ir::Node *n : node->outputs) {
-        graph->Get<ShardedVarDevice>("sharded_var_device")
+        graph->Get<ShardedVarDevice>(kShardedVarDevice)
             .emplace(n->Name(), op_dev_id);
       }
     } else {
@@ -319,7 +325,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
           case BuildStrategy::ReduceStrategy::kReduce:
             cur_device_id = GetAppropriateDeviceID({g_name});
             CreateReduceOp(&result, g_name, cur_device_id);
-            graph->Get<ShardedVarDevice>("sharded_var_device")
+            graph->Get<ShardedVarDevice>(kShardedVarDevice)
                 .emplace(g_name, cur_device_id);
             bcast_var_name_set[cur_device_id].emplace(p_name);
             break;
@@ -406,16 +412,16 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
       result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation),
       local_scopes_, places_);
 #endif
-  result->Get<GraphOps>("ops").emplace_back(op_handle);
+  result->Get<GraphOps>(kGraphOps).emplace_back(op_handle);
 
   auto *in =
-      result->Get<GraphVars>("vars").at(src_dev_id).at(p_name).back().get();
+      result->Get<GraphVars>(kGraphVars).at(src_dev_id).at(p_name).back().get();
   op_handle->AddInput(in);
 
   for (size_t i = 0; i < places_.size(); ++i) {
     auto &p = places_[i];
     SetCommunicationContext(op_handle, p);
-    auto &vars = result->Get<GraphVars>("vars").at(i).at(p_name);
+    auto &vars = result->Get<GraphVars>(kGraphVars).at(i).at(p_name);
     auto *out_var = new VarHandle(
         result->CreateEmptyNode(p_name, ir::Node::Type::kVariable), vars.size(),
         i, p_name, p);
@@ -427,7 +433,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
 void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
                                                     ir::Node *node,
                                                     int dev_id) const {
-  result->Get<GraphOps>("ops").emplace_back(
+  result->Get<GraphOps>(kGraphOps).emplace_back(
       new ComputationOpHandle(result->CreateOpNode(node->Op()),
                               local_scopes_[dev_id], places_[dev_id]));
   CreateOpHandleIOs(result, node, dev_id);
@@ -436,20 +442,20 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
 void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
                                                 const std::string &og) const {
 #ifdef PADDLE_WITH_CUDA
-  result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
+  result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
       result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
       local_scopes_, places_, nccl_ctxs_));
 #else
-  result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
+  result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
       result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
       local_scopes_, places_));
 #endif
-  auto *op_handle = result->Get<GraphOps>("ops").back().get();
+  auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
 
   for (size_t i = 0; i < places_.size(); ++i) {
     auto &p = places_[i];
     SetCommunicationContext(op_handle, p);
-    auto &vars = result->Get<GraphVars>("vars")[i][og];
+    auto &vars = result->Get<GraphVars>(kGraphVars)[i][og];
     PADDLE_ENFORCE(!vars.empty());
     auto &prev_grad = vars.back();
     op_handle->AddInput(prev_grad.get());
@@ -465,20 +471,20 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
 void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
     ir::Graph *result, const std::vector<std::string> &datas) const {
 #ifdef PADDLE_WITH_CUDA
-  result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
+  result->Get<GraphOps>(kGraphOps).emplace_back(new DataBalanceOpHandle(
      result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
      local_scopes_, places_, nccl_ctxs_));
 #else
-  result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
+  result->Get<GraphOps>(kGraphOps).emplace_back(new DataBalanceOpHandle(
      result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
      local_scopes_, places_));
 #endif
-  auto *op_handle = result->Get<GraphOps>("ops").back().get();
+  auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
   for (size_t i = 0; i < places_.size(); ++i) {
     auto &p = places_[i];
     SetCommunicationContext(op_handle, p);
     for (const std::string &d_name : datas) {
-      auto &vars = result->Get<GraphVars>("vars")[i][d_name];
+      auto &vars = result->Get<GraphVars>(kGraphVars)[i][d_name];
       PADDLE_ENFORCE(!vars.empty());
       op_handle->AddInput(vars.back().get());
       auto var = new VarHandle(
@@ -524,7 +530,7 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(const ir::Graph &graph,
 
 int MultiDevSSAGraphBuilder::GetVarDeviceID(const ir::Graph &graph,
                                             const std::string &varname) const {
-  auto &sharded_var_device = graph.Get<ShardedVarDevice>("sharded_var_device");
+  auto &sharded_var_device = graph.Get<ShardedVarDevice>(kShardedVarDevice);
   auto got = sharded_var_device.find(varname);
   return got == sharded_var_device.end() ? -1 : got->second;
 }
@@ -544,7 +550,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(ir::Graph *result) const {
         result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation),
         local_scopes_.size(), local_scopes_[i], places_[i],
         communication_dev_ctx);
-    result->Get<GraphOps>("ops").emplace_back(op_handle);
+    result->Get<GraphOps>(kGraphOps).emplace_back(op_handle);
 
     // FIXME: Currently ScaleLossGradOp only use device_count as scale
     // factor. So it does not depend on any other operators.
@@ -565,7 +571,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result,
   for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
     auto p = places_[scope_idx];
     auto s = local_scopes_[scope_idx];
-    result->Get<GraphOps>("ops").emplace_back(
+    result->Get<GraphOps>(kGraphOps).emplace_back(
         new ComputationOpHandle(result->CreateOpNode(node->Op()), s, p));
     CreateOpHandleIOs(result, node, scope_idx);
   }
@@ -575,25 +581,25 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
                                                    const std::string &og,
                                                    int dst_dev_id) const {
 #ifdef PADDLE_WITH_CUDA
-  result->Get<GraphOps>("ops").emplace_back(new ReduceOpHandle(
+  result->Get<GraphOps>(kGraphOps).emplace_back(new ReduceOpHandle(
       result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
       local_scopes_, places_, nccl_ctxs_));
 #else
-  result->Get<GraphOps>("ops").emplace_back(new ReduceOpHandle(
+  result->Get<GraphOps>(kGraphOps).emplace_back(new ReduceOpHandle(
       result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
       local_scopes_, places_));
 #endif
-  auto *op_handle = result->Get<GraphOps>("ops").back().get();
+  auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
 
   for (size_t i = 0; i < places_.size(); ++i) {
     auto &p = places_[i];
     SetCommunicationContext(op_handle, p);
-    auto &vars = result->Get<GraphVars>("vars")[i][og];
+    auto &vars = result->Get<GraphVars>(kGraphVars)[i][og];
     PADDLE_ENFORCE(!vars.empty());
     auto &prev_grad = vars.back();
     op_handle->AddInput(prev_grad.get());
   }
-  auto &vars = result->Get<GraphVars>("vars")[dst_dev_id][og];
+  auto &vars = result->Get<GraphVars>(kGraphVars)[dst_dev_id][og];
   auto var =
       new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable),
                     vars.size(), dst_dev_id, og, places_[dst_dev_id]);
@@ -606,11 +612,11 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
 // on it.
 void MultiDevSSAGraphBuilder::ConnectOp(ir::Graph *result, OpHandleBase *op,
                                         const std::string &prev_op_name) const {
-  for (auto &prev_op : result->Get<GraphOps>("ops")) {
+  for (auto &prev_op : result->Get<GraphOps>(kGraphOps)) {
     if (prev_op->Name() == prev_op_name) {
       auto *dep_var = new DummyVarHandle(result->CreateControlDepVar());
       prev_op->AddOutput(dep_var);
-      result->Get<GraphDepVars>("dep_vars").emplace(dep_var);
+      result->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
       op->AddInput(dep_var);
     }
   }
@@ -635,18 +641,18 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
     if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
       op_dev_id = GetAppropriateDeviceID(input_var_names);
       for (auto &varname : input_var_names) {
-        result->Get<ShardedVarDevice>("sharded_var_device")
+        result->Get<ShardedVarDevice>(kShardedVarDevice)
             .emplace(varname, op_dev_id);
       }
     }
     for (auto &varname : output_var_names) {
-      result->Get<ShardedVarDevice>("sharded_var_device")
+      result->Get<ShardedVarDevice>(kShardedVarDevice)
          .emplace(varname, op_dev_id);
     }
   } else if (node->Op()->Type() == "concat") {
     op_dev_id = GetVarDeviceID(*result, input_var_names[0]);
     for (auto &varname : output_var_names) {
-      result->Get<ShardedVarDevice>("sharded_var_device")
+      result->Get<ShardedVarDevice>(kShardedVarDevice)
          .emplace(varname, op_dev_id);
     }
   } else {
@@ -661,7 +667,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
 
   CreateComputationalOp(result, node, op_dev_id);
   if (node->Op()->Type() == "concat") {
-    ConnectOp(result, result->Get<GraphOps>("ops").back().get(),
+    ConnectOp(result, result->Get<GraphOps>(kGraphOps).back().get(),
               "fetch_barrier");
   }
 }
@@ -687,7 +693,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
     }
     op_dev_id = GetAppropriateDeviceID(input_var_names);
     for (auto &varname : input_var_names) {
-      result->Get<ShardedVarDevice>("sharded_var_device")
+      result->Get<ShardedVarDevice>(kShardedVarDevice)
          .emplace(varname, op_dev_id);
     }
   }
@@ -698,7 +704,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
     }
     op_dev_id = GetAppropriateDeviceID(output_var_names);
     for (auto &varname : output_var_names) {
-      result->Get<ShardedVarDevice>("sharded_var_device")
+      result->Get<ShardedVarDevice>(kShardedVarDevice)
          .emplace(varname, op_dev_id);
     }
   } else {
@@ -709,17 +715,17 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
   PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s",
                  node->Op()->Type());
 
-  result->Get<GraphOps>("ops").emplace_back(new RPCOpHandle(
+  result->Get<GraphOps>(kGraphOps).emplace_back(new RPCOpHandle(
      result->CreateOpNode(node->Op()), *node->Op(), local_scopes_[op_dev_id],
      node->Op()->Type(), places_[op_dev_id]));
 
   if (node->Op()->Type() == "send_barrier") {
-    ConnectOp(result, result->Get<GraphOps>("ops").back().get(), "send");
+    ConnectOp(result, result->Get<GraphOps>(kGraphOps).back().get(), "send");
   } else if (node->Op()->Type() == "recv") {
-    ConnectOp(result, result->Get<GraphOps>("ops").back().get(),
+    ConnectOp(result, result->Get<GraphOps>(kGraphOps).back().get(),
              "send_barrier");
   } else if (node->Op()->Type() == "fetch_barrier") {
-    ConnectOp(result, result->Get<GraphOps>("ops").back().get(), "recv");
+    ConnectOp(result, result->Get<GraphOps>(kGraphOps).back().get(), "recv");
   } else if (node->Op()->Type() == "send") {
     // do nothing
   } else {
@@ -743,4 +749,9 @@ bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const {
 }  // namespace paddle
 
 REGISTER_PASS(multi_device_pass,
-              paddle::framework::details::MultiDevSSAGraphBuilder);
+              paddle::framework::details::MultiDevSSAGraphBuilder)
+    .RequirePassAttr(paddle::framework::details::kLossVarName)
+    .RequirePassAttr(paddle::framework::details::kPlaces)
+    .RequirePassAttr(paddle::framework::details::kParams)
+    .RequirePassAttr(paddle::framework::details::kLocalScopes)
+    .RequirePassAttr(paddle::framework::details::kStrategy);
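Note on the registration change above: renaming `Apply` to `ApplyImpl` and declaring the five attribute keys with `RequirePassAttr` lets the pass framework check, before the builder runs, that every attribute `Init()` reads was actually supplied by the caller. The sketch below shows how a caller might wire those attributes up. It is a hypothetical example, not part of this diff: the `PassRegistry::Instance().Get`, `SetNotOwned`, and `Apply` calls and the header path are assumptions inferred from the `Get<...>(k...)` reads in `Init()`, and exact signatures may differ by Paddle version.

    // Hypothetical caller-side sketch (assumed API, not confirmed by this diff).
    #include "paddle/fluid/framework/ir/pass.h"

    namespace fw = paddle::framework;
    namespace details = fw::details;

    std::unique_ptr<fw::ir::Graph> BuildMultiDeviceGraph(
        std::unique_ptr<fw::ir::Graph> graph, const std::string &loss_var_name,
        const std::vector<paddle::platform::Place> &places,
        const std::unordered_set<std::string> &params,
        const std::vector<fw::Scope *> &local_scopes,
        const details::BuildStrategy &strategy) {
      auto pass = fw::ir::PassRegistry::Instance().Get("multi_device_pass");
      // One attribute per RequirePassAttr in the registration above; Apply()
      // is expected to fail fast if any of them is missing.
      pass->SetNotOwned<const std::string>(details::kLossVarName,
                                           &loss_var_name);
      pass->SetNotOwned<const std::vector<paddle::platform::Place>>(
          details::kPlaces, &places);
      pass->SetNotOwned<const std::unordered_set<std::string>>(details::kParams,
                                                               &params);
      pass->SetNotOwned<const std::vector<fw::Scope *>>(details::kLocalScopes,
                                                        &local_scopes);
      pass->SetNotOwned<const details::BuildStrategy>(details::kStrategy,
                                                      &strategy);
      return pass->Apply(std::move(graph));
    }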