@@ -90,7 +90,7 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
90
90
// since parameters are all in block 0,
91
91
// it's enough to only scan send ops in block 0
92
92
for (auto &node : nodes) {
93
- if (! node->Op () ) continue ;
93
+ if (node->NodeType () != ir::Node::Type:: kOperation ) continue ;
94
94
OpDesc *op = node->Op ();
95
95
// TODO(Yancey1989): use a graceful method to find send op,
96
96
// instead of the hard code string
@@ -108,7 +108,7 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
108
108
const std::vector<std::unique_ptr<ir::Node>> &nodes) const {
109
109
std::vector<std::string> recv_vars;
110
110
for (auto &node : nodes) {
111
- if (! node->Op () ) continue ;
111
+ if (node->NodeType () != ir::Node::Type:: kOperation ) continue ;
112
112
OpDesc *op = node->Op ();
113
113
// TODO(Yancey1989): use a graceful method to find recv op,
114
114
// instead of the hard code string
@@ -149,10 +149,10 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
149
149
std::vector<std::string> input_var_names;
150
150
std::vector<std::string> output_var_names;
151
151
for (ir::Node *input : node->inputs ) {
152
- input_var_names.push_back (input->Var ()-> Name ());
152
+ input_var_names.push_back (input->Name ());
153
153
}
154
154
for (ir::Node *output : node->outputs ) {
155
- output_var_names.push_back (output->Var ()-> Name ());
155
+ output_var_names.push_back (output->Name ());
156
156
}
157
157
158
158
return checker (output_var_names, send_vars) ||
@@ -181,13 +181,13 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
181
181
182
182
std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply (
183
183
std::unique_ptr<Graph> graph) const {
184
+ // Rebuild the graph structure.
184
185
auto nodes = std::move (graph->nodes );
185
186
graph->nodes .clear ();
186
- LOG (ERROR) << " origin nodes count " << nodes.size ();
187
187
188
188
for (auto &node : nodes) {
189
- if (node->Var () ) {
190
- all_vars_.emplace (node->Var ()-> Name (), node->Var ());
189
+ if (node->NodeType () == ir::Node::Type:: kVariable ) {
190
+ all_vars_.emplace (node->Name (), node->Var ());
191
191
}
192
192
}
193
193
@@ -212,7 +212,7 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
212
212
213
213
// TODO(panyx0718): FIXME: nodes should be sorted by "program" order.
214
214
for (auto &node : nodes) {
215
- if (! node->Op () ) continue ;
215
+ if (node->NodeType () != ir::Node::Type:: kOperation ) continue ;
216
216
if (boost::get<int >(
217
217
node->Op ()->GetAttr (OpProtoAndCheckerMaker::OpRoleAttrName ())) ==
218
218
static_cast <int >(OpRole::kRPC )) {
@@ -235,7 +235,7 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
235
235
if (op_dev_id != -1 ) { // This op only runs on one specific device.
236
236
CreateComputationalOp (&result, node.get (), op_dev_id);
237
237
for (ir::Node *n : node->outputs ) {
238
- var_name_on_devices_.emplace (n->Var ()-> Name (), op_dev_id);
238
+ var_name_on_devices_.emplace (n->Name (), op_dev_id);
239
239
}
240
240
} else {
241
241
// This op runs on all devices, and its output may have parameter's
@@ -351,10 +351,10 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
351
351
const std::string &p_name,
352
352
size_t src_dev_id) const {
353
353
#ifdef PADDLE_WITH_CUDA
354
- auto *op_handle = new BroadcastOpHandle (result->CreateOpNode ( nullptr ),
354
+ auto *op_handle = new BroadcastOpHandle (result->CreateEmptyNode ( " broadcast " ),
355
355
local_scopes_, places_, nccl_ctxs_);
356
356
#else
357
- auto *op_handle = new BroadcastOpHandle (result->CreateOpNode ( nullptr ),
357
+ auto *op_handle = new BroadcastOpHandle (result->CreateEmptyNode ( " broadcast " ),
358
358
local_scopes_, places_);
359
359
#endif
360
360
result->Get <GraphOps>(" ops" ).emplace_back (op_handle);
@@ -367,8 +367,8 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
367
367
auto &p = places_[i];
368
368
SetCommunicationContext (op_handle, p);
369
369
auto &vars = result->Get <GraphVars>(" vars" ).at (i).at (p_name);
370
- auto *out_var =
371
- new VarHandle (result-> CreateVarNode (p_name), vars. size (), i, p_name, p);
370
+ auto *out_var = new VarHandle (result-> CreateEmptyNode (p_name), vars. size (),
371
+ i, p_name, p);
372
372
vars.emplace_back (out_var);
373
373
op_handle->AddOutput (out_var);
374
374
}
@@ -378,19 +378,20 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result,
378
378
ir::Node *node,
379
379
int dev_id) const {
380
380
result->Get <GraphOps>(" ops" ).emplace_back (
381
- new ComputationOpHandle (result->CreateOpNode (node->Op ()), *node-> Op (),
381
+ new ComputationOpHandle (result->CreateOpNode (node->Op ()),
382
382
local_scopes_[dev_id], places_[dev_id]));
383
383
CreateOpHandleIOs (result, node, dev_id);
384
384
}
385
385
386
386
void MultiDevSSAGraphBuilder::InsertAllReduceOp (Graph *result,
387
387
const std::string &og) const {
388
388
#ifdef PADDLE_WITH_CUDA
389
- result->Get <GraphOps>(" ops" ).emplace_back (new AllReduceOpHandle (
390
- result->CreateOpNode (nullptr ), local_scopes_, places_, nccl_ctxs_));
389
+ result->Get <GraphOps>(" ops" ).emplace_back (
390
+ new AllReduceOpHandle (result->CreateEmptyNode (" allreduce" ), local_scopes_,
391
+ places_, nccl_ctxs_));
391
392
#else
392
393
result->Get <GraphOps>(" ops" ).emplace_back (new AllReduceOpHandle (
393
- result->CreateOpNode ( nullptr ), local_scopes_, places_));
394
+ result->CreateEmptyNode ( " allreduce " ), local_scopes_, places_));
394
395
#endif
395
396
auto *op_handle = result->Get <GraphOps>(" ops" ).back ().get ();
396
397
@@ -402,7 +403,8 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
402
403
auto &prev_grad = vars.back ();
403
404
op_handle->AddInput (prev_grad.get ());
404
405
405
- auto var = new VarHandle (result->CreateVarNode (og), vars.size (), i, og, p);
406
+ auto var =
407
+ new VarHandle (result->CreateEmptyNode (og), vars.size (), i, og, p);
406
408
vars.emplace_back (var);
407
409
op_handle->AddOutput (var);
408
410
}
@@ -411,11 +413,12 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
411
413
void MultiDevSSAGraphBuilder::InsertDataBalanceOp (
412
414
Graph *result, const std::vector<std::string> &datas) const {
413
415
#ifdef PADDLE_WITH_CUDA
414
- result->Get <GraphOps>(" ops" ).emplace_back (new DataBalanceOpHandle (
415
- result->CreateOpNode (nullptr ), local_scopes_, places_, nccl_ctxs_));
416
+ result->Get <GraphOps>(" ops" ).emplace_back (
417
+ new DataBalanceOpHandle (result->CreateEmptyNode (" data_balance" ),
418
+ local_scopes_, places_, nccl_ctxs_));
416
419
#else
417
420
result->Get <GraphOps>(" ops" ).emplace_back (new DataBalanceOpHandle (
418
- result->CreateOpNode ( nullptr ), local_scopes_, places_));
421
+ result->CreateEmptyNode ( " data_balance " ), local_scopes_, places_));
419
422
#endif
420
423
auto *op_handle = result->Get <GraphOps>(" ops" ).back ().get ();
421
424
for (size_t i = 0 ; i < places_.size (); ++i) {
@@ -425,7 +428,7 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
425
428
auto &vars = result->Get <GraphVars>(" vars" )[i][d_name];
426
429
PADDLE_ENFORCE (!vars.empty ());
427
430
op_handle->AddInput (vars.back ().get ());
428
- auto var = new VarHandle (result->CreateVarNode (d_name), vars.size (), i,
431
+ auto var = new VarHandle (result->CreateEmptyNode (d_name), vars.size (), i,
429
432
d_name, p);
430
433
vars.emplace_back (var);
431
434
op_handle->AddOutput (var);
@@ -455,12 +458,12 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const {
455
458
return -1 ;
456
459
}
457
460
auto param_grad = boost::get<std::vector<std::string>>(
458
- node->Op ()->. GetAttr (OpProtoAndCheckerMaker::OpRoleVarAttrName ()));
461
+ node->Op ()->GetAttr (OpProtoAndCheckerMaker::OpRoleVarAttrName ()));
459
462
460
463
PADDLE_ENFORCE_EQ (param_grad.size (), 2U );
461
464
int dev_id = GetVarDeviceID (param_grad[1 ]);
462
- PADDLE_ENFORCE_NE (dev_id, -1 , " dev_id should not be -1.[%s, %s]" , op. Type (),
463
- param_grad[0 ]);
465
+ PADDLE_ENFORCE_NE (dev_id, -1 , " dev_id should not be -1.[%s, %s]" ,
466
+ node-> Op ()-> Type (), param_grad[0 ]);
464
467
return dev_id;
465
468
}
466
469
@@ -481,8 +484,8 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
481
484
platform::DeviceContextPool::Instance ().Get (platform::CPUPlace ());
482
485
#endif
483
486
auto *op_handle = new ScaleLossGradOpHandle (
484
- result->CreateOpNode ( nullptr ), local_scopes_.size (), local_scopes_[i] ,
485
- places_[i], communication_dev_ctx);
487
+ result->CreateEmptyNode ( " scale_loss_grad " ), local_scopes_.size (),
488
+ local_scopes_[i], places_[i], communication_dev_ctx);
486
489
result->Get <GraphOps>(" ops" ).emplace_back (op_handle);
487
490
488
491
// FIXME: Currently ScaleLossGradOp only use device_count as scale
@@ -495,7 +498,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
495
498
const std::string grad_var_name = GradVarName (loss_var_name_);
496
499
auto &vars = result->Get <GraphVars>(" vars" )[i][grad_var_name];
497
500
size_t version = vars.size ();
498
- auto var = new VarHandle (result->CreateVarNode (grad_var_name), version, i,
501
+ auto var = new VarHandle (result->CreateEmptyNode (grad_var_name), version, i,
499
502
grad_var_name, places_[i]);
500
503
vars.emplace_back (var);
501
504
op_handle->AddOutput (var);
@@ -508,8 +511,8 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result,
508
511
for (size_t scope_idx = 0 ; scope_idx < num_places; ++scope_idx) {
509
512
auto p = places_[scope_idx];
510
513
auto s = local_scopes_[scope_idx];
511
- result->Get <GraphOps>(" ops" ).emplace_back (new ComputationOpHandle (
512
- result->CreateOpNode (node->Op ()), *node-> Op ( ), s, p));
514
+ result->Get <GraphOps>(" ops" ).emplace_back (
515
+ new ComputationOpHandle ( result->CreateOpNode (node->Op ()), s, p));
513
516
CreateOpHandleIOs (result, node, scope_idx);
514
517
}
515
518
}
@@ -519,10 +522,10 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
519
522
int dst_dev_id) const {
520
523
#ifdef PADDLE_WITH_CUDA
521
524
result->Get <GraphOps>(" ops" ).emplace_back (new ReduceOpHandle (
522
- result->CreateOpNode ( nullptr ), local_scopes_, places_, nccl_ctxs_));
525
+ result->CreateEmptyNode ( " reduce " ), local_scopes_, places_, nccl_ctxs_));
523
526
#else
524
527
result->Get <GraphOps>(" ops" ).emplace_back (new ReduceOpHandle (
525
- result->CreateOpNode ( nullptr ), local_scopes_, places_));
528
+ result->CreateEmptyNode ( " reduce " ), local_scopes_, places_));
526
529
#endif
527
530
auto *op_handle = result->Get <GraphOps>(" ops" ).back ().get ();
528
531
@@ -535,7 +538,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
535
538
op_handle->AddInput (prev_grad.get ());
536
539
}
537
540
auto &vars = result->Get <GraphVars>(" vars" )[dst_dev_id][og];
538
- auto var = new VarHandle (result->CreateVarNode (og), vars.size (), dst_dev_id,
541
+ auto var = new VarHandle (result->CreateEmptyNode (og), vars.size (), dst_dev_id,
539
542
og, places_[dst_dev_id]);
540
543
vars.emplace_back (var);
541
544
op_handle->AddOutput (var);
@@ -548,7 +551,7 @@ void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op,
548
551
const std::string &prev_op_name) const {
549
552
for (auto &prev_op : result->Get <GraphOps>(" ops" )) {
550
553
if (prev_op->Name () == prev_op_name) {
551
- auto *dep_var = new DummyVarHandle (result->CreateVarNode (" dummy" ));
554
+ auto *dep_var = new DummyVarHandle (result->CreateEmptyNode (" dummy" ));
552
555
prev_op->AddOutput (dep_var);
553
556
result->Get <GraphDepVars>(" dep_vars" ).emplace (dep_var);
554
557
op->AddInput (dep_var);
@@ -562,10 +565,10 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result,
562
565
std::vector<std::string> input_var_names;
563
566
std::vector<std::string> output_var_names;
564
567
for (ir::Node *input : node->inputs ) {
565
- input_var_names.push_back (input->Var ()-> Name ());
568
+ input_var_names.push_back (input->Name ());
566
569
}
567
570
for (ir::Node *output : node->outputs ) {
568
- output_var_names.push_back (output->Var ()-> Name ());
571
+ output_var_names.push_back (output->Name ());
569
572
}
570
573
571
574
if (node->Op ()->Type () == " split_byref" ||
@@ -606,16 +609,16 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result,
606
609
void MultiDevSSAGraphBuilder::CreateRPCOp (Graph *result, ir::Node *node) const {
607
610
int op_dev_id = -1 ;
608
611
if (node->Op ()->Type () == " send" ) {
609
- op_dev_id = GetVarDeviceID (node->inputs [0 ]->Var ()-> Name ());
612
+ op_dev_id = GetVarDeviceID (node->inputs [0 ]->Name ());
610
613
// the variable name which contains .block means it was split by
611
614
// split_byref op
612
615
// so that we can balance the variable blocks to all the pserver
613
616
// instances.
614
617
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce &&
615
- node->inputs [0 ]->Var ()-> Name ().find (" .block" ) == std::string::npos) {
618
+ node->inputs [0 ]->Name ().find (" .block" ) == std::string::npos) {
616
619
std::vector<std::string> input_var_names;
617
620
for (ir::Node *n : node->inputs ) {
618
- input_var_names.push_back (n->Var ()-> Name ());
621
+ input_var_names.push_back (n->Name ());
619
622
}
620
623
op_dev_id = GetAppropriateDeviceID (input_var_names);
621
624
for (auto &varname : input_var_names) {
@@ -625,7 +628,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result, ir::Node *node) const {
625
628
} else if (node->Op ()->Type () == " recv" ) {
626
629
std::vector<std::string> output_var_names;
627
630
for (ir::Node *n : node->outputs ) {
628
- output_var_names.push_back (n->Var ()-> Name ());
631
+ output_var_names.push_back (n->Name ());
629
632
}
630
633
op_dev_id = GetAppropriateDeviceID (output_var_names);
631
634
for (auto &varname : output_var_names) {
0 commit comments