@@ -80,7 +80,14 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(Graph *result, ir::Node *node,
80
80
}
81
81
82
82
for (ir::Node *output : node->outputs ) {
83
- CreateOpOutput (result, op_handle, output, p, place_id);
83
+ ir::Node *new_node = nullptr ;
84
+ if (output->Var ()) {
85
+ new_node = result->CreateVarNode (output->Var ());
86
+ } else {
87
+ new_node =
88
+ result->CreateEmptyNode (output->Name (), ir::Node::Type::kVariable );
89
+ }
90
+ CreateOpOutput (result, op_handle, new_node, p, place_id);
84
91
}
85
92
}
86
93
@@ -246,7 +253,7 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
246
253
if (node->Op ()->Type () == " read" && strategy_.enable_data_balance_ ) {
247
254
node->Op ()->SetAttr (" throw_eof_exp" , false );
248
255
CreateComputationalOps (&result, node.get (), places_.size ());
249
- // TODO(panyx0718 ): builder shouldn't depend on the out logic of
256
+ // TODO(paddle-dev ): builder shouldn't depend on the out logic of
250
257
// a specific op.
251
258
const auto &data_var_names = node->Op ()->Output (" Out" );
252
259
InsertDataBalanceOp (&result, data_var_names);
@@ -354,11 +361,13 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
354
361
const std::string &p_name,
355
362
size_t src_dev_id) const {
356
363
#ifdef PADDLE_WITH_CUDA
357
- auto *op_handle = new BroadcastOpHandle (result->CreateEmptyNode (" broadcast" ),
358
- local_scopes_, places_, nccl_ctxs_);
364
+ auto *op_handle = new BroadcastOpHandle (
365
+ result->CreateEmptyNode (" broadcast" , ir::Node::Type::kOperation ),
366
+ local_scopes_, places_, nccl_ctxs_);
359
367
#else
360
- auto *op_handle = new BroadcastOpHandle (result->CreateEmptyNode (" broadcast" ),
361
- local_scopes_, places_);
368
+ auto *op_handle = new BroadcastOpHandle (
369
+ result->CreateEmptyNode (" broadcast" , ir::Node::Type::kOperation ),
370
+ local_scopes_, places_);
362
371
#endif
363
372
result->Get <GraphOps>(" ops" ).emplace_back (op_handle);
364
373
@@ -370,8 +379,9 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
370
379
auto &p = places_[i];
371
380
SetCommunicationContext (op_handle, p);
372
381
auto &vars = result->Get <GraphVars>(" vars" ).at (i).at (p_name);
373
- auto *out_var = new VarHandle (result->CreateEmptyNode (p_name), vars.size (),
374
- i, p_name, p);
382
+ auto *out_var = new VarHandle (
383
+ result->CreateEmptyNode (p_name, ir::Node::Type::kVariable ), vars.size (),
384
+ i, p_name, p);
375
385
vars.emplace_back (out_var);
376
386
op_handle->AddOutput (out_var);
377
387
}
@@ -389,12 +399,13 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result,
389
399
void MultiDevSSAGraphBuilder::InsertAllReduceOp (Graph *result,
390
400
const std::string &og) const {
391
401
#ifdef PADDLE_WITH_CUDA
392
- result->Get <GraphOps>(" ops" ).emplace_back (
393
- new AllReduceOpHandle ( result->CreateEmptyNode (" allreduce" ), local_scopes_ ,
394
- places_, nccl_ctxs_));
402
+ result->Get <GraphOps>(" ops" ).emplace_back (new AllReduceOpHandle (
403
+ result->CreateEmptyNode (" allreduce" , ir::Node::Type:: kOperation ) ,
404
+ local_scopes_, places_, nccl_ctxs_));
395
405
#else
396
406
result->Get <GraphOps>(" ops" ).emplace_back (new AllReduceOpHandle (
397
- result->CreateEmptyNode (" allreduce" ), local_scopes_, places_));
407
+ result->CreateEmptyNode (" allreduce" , ir::Node::Type::kOperation ),
408
+ local_scopes_, places_));
398
409
#endif
399
410
auto *op_handle = result->Get <GraphOps>(" ops" ).back ().get ();
400
411
@@ -407,7 +418,8 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
407
418
op_handle->AddInput (prev_grad.get ());
408
419
409
420
auto var =
410
- new VarHandle (result->CreateEmptyNode (og), vars.size (), i, og, p);
421
+ new VarHandle (result->CreateEmptyNode (og, ir::Node::Type::kVariable ),
422
+ vars.size (), i, og, p);
411
423
vars.emplace_back (var);
412
424
op_handle->AddOutput (var);
413
425
}
@@ -416,12 +428,13 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
416
428
void MultiDevSSAGraphBuilder::InsertDataBalanceOp (
417
429
Graph *result, const std::vector<std::string> &datas) const {
418
430
#ifdef PADDLE_WITH_CUDA
419
- result->Get <GraphOps>(" ops" ).emplace_back (
420
- new DataBalanceOpHandle ( result->CreateEmptyNode (" data_balance" ),
421
- local_scopes_, places_, nccl_ctxs_));
431
+ result->Get <GraphOps>(" ops" ).emplace_back (new DataBalanceOpHandle (
432
+ result->CreateEmptyNode (" data_balance" , ir::Node::Type:: kOperation ),
433
+ local_scopes_, places_, nccl_ctxs_));
422
434
#else
423
435
result->Get <GraphOps>(" ops" ).emplace_back (new DataBalanceOpHandle (
424
- result->CreateEmptyNode (" data_balance" ), local_scopes_, places_));
436
+ result->CreateEmptyNode (" data_balance" , ir::Node::Type::kOperation ),
437
+ local_scopes_, places_));
425
438
#endif
426
439
auto *op_handle = result->Get <GraphOps>(" ops" ).back ().get ();
427
440
for (size_t i = 0 ; i < places_.size (); ++i) {
@@ -431,8 +444,9 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
431
444
auto &vars = result->Get <GraphVars>(" vars" )[i][d_name];
432
445
PADDLE_ENFORCE (!vars.empty ());
433
446
op_handle->AddInput (vars.back ().get ());
434
- auto var = new VarHandle (result->CreateEmptyNode (d_name), vars.size (), i,
435
- d_name, p);
447
+ auto var = new VarHandle (
448
+ result->CreateEmptyNode (d_name, ir::Node::Type::kVariable ),
449
+ vars.size (), i, d_name, p);
436
450
vars.emplace_back (var);
437
451
op_handle->AddOutput (var);
438
452
}
@@ -487,8 +501,9 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
487
501
platform::DeviceContextPool::Instance ().Get (platform::CPUPlace ());
488
502
#endif
489
503
auto *op_handle = new ScaleLossGradOpHandle (
490
- result->CreateEmptyNode (" scale_loss_grad" ), local_scopes_.size (),
491
- local_scopes_[i], places_[i], communication_dev_ctx);
504
+ result->CreateEmptyNode (" scale_loss_grad" , ir::Node::Type::kOperation ),
505
+ local_scopes_.size (), local_scopes_[i], places_[i],
506
+ communication_dev_ctx);
492
507
result->Get <GraphOps>(" ops" ).emplace_back (op_handle);
493
508
494
509
// FIXME: Currently ScaleLossGradOp only use device_count as scale
@@ -497,14 +512,10 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
497
512
// loss->pending_ops_.emplace_back(op_handle);
498
513
// op_handle->inputs_.emplace_back(loss);
499
514
500
- // TODO(panyx0718): GradVarName(loss_var_name_)
501
- const std::string grad_var_name = GradVarName (loss_var_name_);
502
- auto &vars = result->Get <GraphVars>(" vars" )[i][grad_var_name];
503
- size_t version = vars.size ();
504
- auto var = new VarHandle (result->CreateEmptyNode (grad_var_name), version, i,
505
- grad_var_name, places_[i]);
506
- vars.emplace_back (var);
507
- op_handle->AddOutput (var);
515
+ CreateOpOutput (result, op_handle,
516
+ result->CreateEmptyNode (GradVarName (loss_var_name_),
517
+ ir::Node::Type::kVariable ),
518
+ places_[i], i);
508
519
}
509
520
}
510
521
@@ -525,10 +536,12 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
525
536
int dst_dev_id) const {
526
537
#ifdef PADDLE_WITH_CUDA
527
538
result->Get <GraphOps>(" ops" ).emplace_back (new ReduceOpHandle (
528
- result->CreateEmptyNode (" reduce" ), local_scopes_, places_, nccl_ctxs_));
539
+ result->CreateEmptyNode (" reduce" , ir::Node::Type::kOperation ),
540
+ local_scopes_, places_, nccl_ctxs_));
529
541
#else
530
542
result->Get <GraphOps>(" ops" ).emplace_back (new ReduceOpHandle (
531
- result->CreateEmptyNode (" reduce" ), local_scopes_, places_));
543
+ result->CreateEmptyNode (" reduce" , ir::Node::Type::kOperation ),
544
+ local_scopes_, places_));
532
545
#endif
533
546
auto *op_handle = result->Get <GraphOps>(" ops" ).back ().get ();
534
547
@@ -541,8 +554,9 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
541
554
op_handle->AddInput (prev_grad.get ());
542
555
}
543
556
auto &vars = result->Get <GraphVars>(" vars" )[dst_dev_id][og];
544
- auto var = new VarHandle (result->CreateEmptyNode (og), vars.size (), dst_dev_id,
545
- og, places_[dst_dev_id]);
557
+ auto var =
558
+ new VarHandle (result->CreateEmptyNode (og, ir::Node::Type::kVariable ),
559
+ vars.size (), dst_dev_id, og, places_[dst_dev_id]);
546
560
vars.emplace_back (var);
547
561
op_handle->AddOutput (var);
548
562
return var;
@@ -554,7 +568,8 @@ void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op,
554
568
const std::string &prev_op_name) const {
555
569
for (auto &prev_op : result->Get <GraphOps>(" ops" )) {
556
570
if (prev_op->Name () == prev_op_name) {
557
- auto *dep_var = new DummyVarHandle (result->CreateEmptyNode (" dummy" ));
571
+ auto *dep_var = new DummyVarHandle (
572
+ result->CreateEmptyNode (" dummy" , ir::Node::Type::kVariable ));
558
573
prev_op->AddOutput (dep_var);
559
574
result->Get <GraphDepVars>(" dep_vars" ).emplace (dep_var);
560
575
op->AddInput (dep_var);
0 commit comments