
Commit 10786a2

polish graph
1 parent 2fa8df1 commit 10786a2

12 files changed: 104 additions & 113 deletions
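The changes below all revolve around the ir::Node / Graph interface. The following is a minimal standalone sketch of that interface as inferred from this diff alone — the real definitions live under paddle/fluid/framework/ir/ and may differ:

#include <string>

// Inferred sketch, not the real class: every node now carries a name, exposes
// its kind via NodeType(), and may or may not hold an OpDesc/VarDesc.
class OpDesc;
class VarDesc;

class Node {
 public:
  enum class Type { kNone, kOperation, kVariable };

  explicit Node(const std::string &name)  // replaces the old ir::Node()
      : name_(name) {}

  const std::string &Name() const { return name_; }  // valid for every node
  Type NodeType() const { return type_; }

  OpDesc *Op() const { return op_desc_; }    // null for "empty" nodes
  VarDesc *Var() const { return var_desc_; }

 private:
  std::string name_;
  Type type_{Type::kNone};
  OpDesc *op_desc_{nullptr};
  VarDesc *var_desc_{nullptr};
};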

paddle/fluid/framework/details/broadcast_op_handle_test.cc

Lines changed: 5 additions & 5 deletions

@@ -96,7 +96,7 @@ struct TestBroadcastOpHandle {
     }
     param_scopes_[input_scope_idx]->Var("input");
 
-    std::unique_ptr<ir::Node> n(new ir::Node());
+    std::unique_ptr<ir::Node> n(new ir::Node("node0"));
     if (use_gpu_) {
 #ifdef PADDLE_WITH_CUDA
       op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_,
@@ -114,15 +114,15 @@ struct TestBroadcastOpHandle {
 #endif
     }
 
-    std::unique_ptr<ir::Node> v(new ir::Node());
+    std::unique_ptr<ir::Node> v(new ir::Node("node1"));
     auto* in_var_handle = new VarHandle(v.get(), 1, input_scope_idx, "input",
                                         gpu_list_[input_scope_idx]);
     vars_.emplace_back(in_var_handle);
     op_handle_->AddInput(in_var_handle);
 
     // add dummy var
 
-    std::unique_ptr<ir::Node> v2(new ir::Node());
+    std::unique_ptr<ir::Node> v2(new ir::Node("node2"));
     vars_.emplace_back(new DummyVarHandle(v2.get()));
     DummyVarHandle* dummy_var_handle =
         static_cast<DummyVarHandle*>(vars_.back().get());
@@ -133,15 +133,15 @@ struct TestBroadcastOpHandle {
       if (!use_gpu_) {
         op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
       }
-      std::unique_ptr<ir::Node> v3(new ir::Node());
+      std::unique_ptr<ir::Node> v3(new ir::Node("node3"));
       VarHandle* out_var_handle =
           new VarHandle(v3.get(), 2, j, "out", gpu_list_[j]);
       vars_.emplace_back(out_var_handle);
       op_handle_->AddOutput(out_var_handle);
     }
 
     // add dummy var
-    std::unique_ptr<ir::Node> v4(new ir::Node());
+    std::unique_ptr<ir::Node> v4(new ir::Node("node4"));
     vars_.emplace_back(new DummyVarHandle(v4.get()));
     DummyVarHandle* out_dummy_var_handle =
         static_cast<DummyVarHandle*>(vars_.back().get());
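Note that VarHandle and BroadcastOpHandle only borrow the raw ir::Node pointers, so the unique_ptrs above must outlive every handle that references them. A minimal sketch of that ownership contract, with hypothetical stand-in types:

#include <memory>
#include <string>
#include <vector>

struct Node {  // hypothetical stand-in for ir::Node
  explicit Node(std::string n) : name(std::move(n)) {}
  std::string name;
};

struct Handle {  // hypothetical stand-in for VarHandle/OpHandleBase
  explicit Handle(Node *n) : node(n) {}  // borrows, does not own
  Node *node;
};

struct Fixture {
  std::vector<std::unique_ptr<Node>> nodes;      // owners...
  std::vector<std::unique_ptr<Handle>> handles;  // ...must outlive borrowers

  void Add(const std::string &name) {
    nodes.emplace_back(new Node(name));
    handles.emplace_back(new Handle(nodes.back().get()));
  }
};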

paddle/fluid/framework/details/computation_op_handle.cc

Lines changed: 3 additions & 3 deletions

@@ -19,10 +19,10 @@
 namespace paddle {
 namespace framework {
 namespace details {
-ComputationOpHandle::ComputationOpHandle(ir::Node *node, const OpDesc &op_desc,
-                                         Scope *scope, platform::Place place)
+ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope,
+                                         platform::Place place)
     : OpHandleBase(node),
-      op_(framework::OpRegistry::CreateOp(op_desc)),
+      op_(framework::OpRegistry::CreateOp(*node->Op())),
       scope_(scope),
       place_(place) {}

paddle/fluid/framework/details/computation_op_handle.h

Lines changed: 1 addition & 2 deletions

@@ -28,8 +28,7 @@ namespace framework {
 namespace details {
 struct ComputationOpHandle : public OpHandleBase {
  public:
-  ComputationOpHandle(ir::Node *node, const OpDesc &op_desc, Scope *scope,
-                      platform::Place place);
+  ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place);
 
   std::string Name() const override;
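With this pair of changes, the node becomes the single source of truth for the op description: the handle reconstructs its operator via node->Op() instead of taking an OpDesc by value. A tiny sketch of the invariant the new constructor relies on, with hypothetical stand-in types:

#include <cassert>
#include <string>

struct OpDesc { std::string type; };  // hypothetical stand-in

struct Node {
  explicit Node(OpDesc *desc) : op_desc(desc) {}
  OpDesc *Op() const { return op_desc; }
  OpDesc *op_desc;
};

struct ComputationHandleSketch {
  explicit ComputationHandleSketch(Node *node) {
    // The node must carry a desc; a desc-less placeholder node would fail
    // here, which is presumably why the builder below only ever passes
    // CreateOpNode(node->Op()) nodes to ComputationOpHandle.
    assert(node->Op() != nullptr);
    op_type = node->Op()->type;
  }
  std::string op_type;
};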

paddle/fluid/framework/details/gather_op_handle_test.cc

Lines changed: 5 additions & 5 deletions

@@ -82,36 +82,36 @@ struct TestGatherOpHandle {
     }
     param_scopes_[input_scope_idx]->Var("out");
 
-    nodes.emplace_back(new ir::Node());
+    nodes.emplace_back(new ir::Node("node"));
     op_handle_.reset(
         new GatherOpHandle(nodes.back().get(), local_scopes_, gpu_list_));
     // add input
     for (size_t j = 0; j < gpu_list_.size(); ++j) {
       op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
-      nodes.emplace_back(new ir::Node());
+      nodes.emplace_back(new ir::Node("node1"));
       auto* in_var_handle =
           new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]);
       vars_.emplace_back(in_var_handle);
       op_handle_->AddInput(in_var_handle);
     }
 
     // add dummy var
-    nodes.emplace_back(new ir::Node());
+    nodes.emplace_back(new ir::Node("node2"));
     vars_.emplace_back(new DummyVarHandle(nodes.back().get()));
     DummyVarHandle* in_dummy_var_handle =
         static_cast<DummyVarHandle*>(vars_.back().get());
     in_dummy_var_handle->ClearGeneratedOp();
     op_handle_->AddInput(in_dummy_var_handle);
 
     // add output
-    nodes.emplace_back(new ir::Node());
+    nodes.emplace_back(new ir::Node("node3"));
     auto* out_var_handle = new VarHandle(nodes.back().get(), 2, input_scope_idx,
                                          "out", gpu_list_[input_scope_idx]);
     vars_.emplace_back(out_var_handle);
     op_handle_->AddOutput(out_var_handle);
 
     // add dummy var
-    nodes.emplace_back(new ir::Node());
+    nodes.emplace_back(new ir::Node("node4"));
     vars_.emplace_back(new DummyVarHandle(nodes.back().get()));
     DummyVarHandle* dummy_var_handle =
         static_cast<DummyVarHandle*>(vars_.back().get());
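Unlike the broadcast test, which names one unique_ptr per node, this fixture funnels every node through a single nodes vector and hands out nodes.back().get(). The same idea as a small helper, with hypothetical types:

#include <memory>
#include <string>
#include <vector>

struct Node {  // hypothetical stand-in for ir::Node
  explicit Node(std::string n) : name(std::move(n)) {}
  std::string name;
};

class NodePool {
 public:
  // Create-and-borrow in one expression; the pool keeps ownership, so the
  // returned pointer stays valid for the pool's lifetime.
  Node *Make(const std::string &name) {
    nodes_.emplace_back(new Node(name));
    return nodes_.back().get();
  }

 private:
  std::vector<std::unique_ptr<Node>> nodes_;
};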

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 43 additions & 40 deletions

@@ -90,7 +90,7 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
   // since parameters are all in block 0,
   // it's enough to only scan send ops in block 0
   for (auto &node : nodes) {
-    if (!node->Op()) continue;
+    if (node->NodeType() != ir::Node::Type::kOperation) continue;
     OpDesc *op = node->Op();
     // TODO(Yancey1989): use a graceful method to find send op,
     // instead of the the hard code string
@@ -108,7 +108,7 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
     const std::vector<std::unique_ptr<ir::Node>> &nodes) const {
   std::vector<std::string> recv_vars;
   for (auto &node : nodes) {
-    if (!node->Op()) continue;
+    if (node->NodeType() != ir::Node::Type::kOperation) continue;
     OpDesc *op = node->Op();
     // TODO(Yancey1989): use a graceful method to find recv op,
     // instead of the hard code string
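The switch from `!node->Op()` to an explicit NodeType() test appears to be about the new empty nodes: once a graph can contain operation nodes that carry no OpDesc, "has a desc" and "is an operation" stop being the same predicate, and the type enum states the intent exactly. A sketch of the distinction, inferred, with hypothetical types:

struct OpDesc {};

struct Node {  // hypothetical stand-in
  enum class Type { kNone, kOperation, kVariable };
  Type type{Type::kNone};
  OpDesc *op{nullptr};

  Type NodeType() const { return type; }
  OpDesc *Op() const { return op; }
};

bool IsOperation(const Node &n) {
  // A placeholder op node (type set, desc null) would be skipped by the old
  // `!n.Op()` test even though it is an operation; the type enum is exact.
  return n.NodeType() == Node::Type::kOperation;
}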
@@ -149,10 +149,10 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
   std::vector<std::string> input_var_names;
   std::vector<std::string> output_var_names;
   for (ir::Node *input : node->inputs) {
-    input_var_names.push_back(input->Var()->Name());
+    input_var_names.push_back(input->Name());
   }
   for (ir::Node *output : node->outputs) {
-    output_var_names.push_back(output->Var()->Name());
+    output_var_names.push_back(output->Name());
   }
 
   return checker(output_var_names, send_vars) ||
@@ -181,13 +181,13 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
 
 std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
     std::unique_ptr<Graph> graph) const {
+  // Rebuild the graph structure.
   auto nodes = std::move(graph->nodes);
   graph->nodes.clear();
-  LOG(ERROR) << "origin nodes count " << nodes.size();
 
   for (auto &node : nodes) {
-    if (node->Var()) {
-      all_vars_.emplace(node->Var()->Name(), node->Var());
+    if (node->NodeType() == ir::Node::Type::kVariable) {
+      all_vars_.emplace(node->Name(), node->Var());
     }
   }
 
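Apply() now starts by explicitly taking every node out of the incoming graph and re-indexing the variable nodes by name; the stray LOG(ERROR) debugging line is dropped in favor of a real comment. The rough shape of the rebuild step, with hypothetical stand-ins:

#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct VarDesc {};
struct Node {  // hypothetical stand-in
  enum class Type { kOperation, kVariable };
  Type type;
  std::string name;
  VarDesc *var{nullptr};
};
struct Graph { std::vector<std::unique_ptr<Node>> nodes; };

void Rebuild(Graph *graph) {
  // Take ownership of everything; the builder re-creates the op structure.
  auto nodes = std::move(graph->nodes);
  graph->nodes.clear();

  std::map<std::string, VarDesc *> all_vars;
  for (auto &node : nodes) {
    if (node->type == Node::Type::kVariable) {
      all_vars.emplace(node->name, node->var);  // keyed by the node's own name
    }
  }
  // ... walk the operation nodes and repopulate graph->nodes ...
}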

@@ -212,7 +212,7 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
 
   // TODO(panyx0718): FIXME: nodes should be sorted by "program" order.
   for (auto &node : nodes) {
-    if (!node->Op()) continue;
+    if (node->NodeType() != ir::Node::Type::kOperation) continue;
     if (boost::get<int>(
             node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
         static_cast<int>(OpRole::kRPC)) {
@@ -235,7 +235,7 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
     if (op_dev_id != -1) {  // This op only runs on one specific device.
       CreateComputationalOp(&result, node.get(), op_dev_id);
       for (ir::Node *n : node->outputs) {
-        var_name_on_devices_.emplace(n->Var()->Name(), op_dev_id);
+        var_name_on_devices_.emplace(n->Name(), op_dev_id);
       }
     } else {
       // This op runs on all devices, and its output may have parameter's
@@ -351,10 +351,10 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
                                                 const std::string &p_name,
                                                 size_t src_dev_id) const {
 #ifdef PADDLE_WITH_CUDA
-  auto *op_handle = new BroadcastOpHandle(result->CreateOpNode(nullptr),
+  auto *op_handle = new BroadcastOpHandle(result->CreateEmptyNode("broadcast"),
                                           local_scopes_, places_, nccl_ctxs_);
 #else
-  auto *op_handle = new BroadcastOpHandle(result->CreateOpNode(nullptr),
+  auto *op_handle = new BroadcastOpHandle(result->CreateEmptyNode("broadcast"),
                                           local_scopes_, places_);
 #endif
   result->Get<GraphOps>("ops").emplace_back(op_handle);
@@ -367,8 +367,8 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
     auto &p = places_[i];
     SetCommunicationContext(op_handle, p);
     auto &vars = result->Get<GraphVars>("vars").at(i).at(p_name);
-    auto *out_var =
-        new VarHandle(result->CreateVarNode(p_name), vars.size(), i, p_name, p);
+    auto *out_var = new VarHandle(result->CreateEmptyNode(p_name), vars.size(),
+                                  i, p_name, p);
     vars.emplace_back(out_var);
     op_handle->AddOutput(out_var);
   }
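CreateOpNode(nullptr) and CreateVarNode(name) give way to CreateEmptyNode(name) throughout: ops and vars that the builder injects (broadcast, allreduce, reduce, scale_loss_grad, data_balance, dummy dep vars) have no desc in the program, so the graph now hands out named placeholder nodes instead of desc-backed ones. A sketch of the assumed factory split — inferred from the diff, not the real Graph class:

#include <memory>
#include <string>
#include <utility>
#include <vector>

struct OpDesc {};

struct Node {  // hypothetical stand-in
  Node(std::string n, OpDesc *d) : name(std::move(n)), desc(d) {}
  std::string name;
  OpDesc *desc;
};

class Graph {
 public:
  // For ops that exist in the program: the node keeps its OpDesc.
  Node *CreateOpNode(OpDesc *desc) {
    nodes_.emplace_back(new Node("op", desc));
    return nodes_.back().get();
  }
  // For builder-injected ops/vars: no desc, but always a descriptive name,
  // so nothing downstream dereferences a null desc just to label the node.
  Node *CreateEmptyNode(const std::string &name) {
    nodes_.emplace_back(new Node(name, nullptr));
    return nodes_.back().get();
  }

 private:
  std::vector<std::unique_ptr<Node>> nodes_;
};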
@@ -378,19 +378,20 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result,
                                                     ir::Node *node,
                                                     int dev_id) const {
   result->Get<GraphOps>("ops").emplace_back(
-      new ComputationOpHandle(result->CreateOpNode(node->Op()), *node->Op(),
+      new ComputationOpHandle(result->CreateOpNode(node->Op()),
                               local_scopes_[dev_id], places_[dev_id]));
   CreateOpHandleIOs(result, node, dev_id);
 }
 
 void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
                                                 const std::string &og) const {
 #ifdef PADDLE_WITH_CUDA
-  result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
-      result->CreateOpNode(nullptr), local_scopes_, places_, nccl_ctxs_));
+  result->Get<GraphOps>("ops").emplace_back(
+      new AllReduceOpHandle(result->CreateEmptyNode("allreduce"), local_scopes_,
+                            places_, nccl_ctxs_));
 #else
   result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
-      result->CreateOpNode(nullptr), local_scopes_, places_));
+      result->CreateEmptyNode("allreduce"), local_scopes_, places_));
 #endif
   auto *op_handle = result->Get<GraphOps>("ops").back().get();
 
@@ -402,7 +403,8 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
     auto &prev_grad = vars.back();
     op_handle->AddInput(prev_grad.get());
 
-    auto var = new VarHandle(result->CreateVarNode(og), vars.size(), i, og, p);
+    auto var =
+        new VarHandle(result->CreateEmptyNode(og), vars.size(), i, og, p);
     vars.emplace_back(var);
     op_handle->AddOutput(var);
   }
@@ -411,11 +413,12 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
 void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
     Graph *result, const std::vector<std::string> &datas) const {
 #ifdef PADDLE_WITH_CUDA
-  result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
-      result->CreateOpNode(nullptr), local_scopes_, places_, nccl_ctxs_));
+  result->Get<GraphOps>("ops").emplace_back(
+      new DataBalanceOpHandle(result->CreateEmptyNode("data_balance"),
+                              local_scopes_, places_, nccl_ctxs_));
 #else
   result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
-      result->CreateOpNode(nullptr), local_scopes_, places_));
+      result->CreateEmptyNode("data_balance"), local_scopes_, places_));
 #endif
   auto *op_handle = result->Get<GraphOps>("ops").back().get();
   for (size_t i = 0; i < places_.size(); ++i) {
@@ -425,7 +428,7 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
     auto &vars = result->Get<GraphVars>("vars")[i][d_name];
     PADDLE_ENFORCE(!vars.empty());
     op_handle->AddInput(vars.back().get());
-    auto var = new VarHandle(result->CreateVarNode(d_name), vars.size(), i,
+    auto var = new VarHandle(result->CreateEmptyNode(d_name), vars.size(), i,
                              d_name, p);
     vars.emplace_back(var);
     op_handle->AddOutput(var);
@@ -455,12 +458,12 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const {
     return -1;
   }
   auto param_grad = boost::get<std::vector<std::string>>(
-      node->Op()->.GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
+      node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
 
   PADDLE_ENFORCE_EQ(param_grad.size(), 2U);
   int dev_id = GetVarDeviceID(param_grad[1]);
-  PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s]", op.Type(),
-                    param_grad[0]);
+  PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s]",
+                    node->Op()->Type(), param_grad[0]);
   return dev_id;
 }
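Two fixes land in GetOpDeviceID: the `->.GetAttr` syntax error, and an error message that referenced a local `op` which does not exist in this function; both now go through node->Op(). The fixed logic, mirrored with hypothetical stand-ins:

#include <cassert>
#include <string>
#include <vector>

struct OpDescLike {  // hypothetical stand-in
  std::string type;
  std::vector<std::string> role_var;  // {param, grad}, from OpRoleVarAttrName
  const std::string &Type() const { return type; }
};

int DeviceFor(const OpDescLike &op, int (*var_device)(const std::string &)) {
  assert(op.role_var.size() == 2u);         // PADDLE_ENFORCE_EQ analogue
  int dev_id = var_device(op.role_var[1]);  // the grad's device wins
  assert(dev_id != -1);                     // PADDLE_ENFORCE_NE analogue; the
                                            // message is built from op.Type()
  return dev_id;
}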

@@ -481,8 +484,8 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
         platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
 #endif
     auto *op_handle = new ScaleLossGradOpHandle(
-        result->CreateOpNode(nullptr), local_scopes_.size(), local_scopes_[i],
-        places_[i], communication_dev_ctx);
+        result->CreateEmptyNode("scale_loss_grad"), local_scopes_.size(),
+        local_scopes_[i], places_[i], communication_dev_ctx);
     result->Get<GraphOps>("ops").emplace_back(op_handle);
 
     // FIXME: Currently ScaleLossGradOp only use device_count as scale
@@ -495,7 +498,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
     const std::string grad_var_name = GradVarName(loss_var_name_);
     auto &vars = result->Get<GraphVars>("vars")[i][grad_var_name];
     size_t version = vars.size();
-    auto var = new VarHandle(result->CreateVarNode(grad_var_name), version, i,
+    auto var = new VarHandle(result->CreateEmptyNode(grad_var_name), version, i,
                              grad_var_name, places_[i]);
     vars.emplace_back(var);
     op_handle->AddOutput(var);
@@ -508,8 +511,8 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result,
   for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
     auto p = places_[scope_idx];
     auto s = local_scopes_[scope_idx];
-    result->Get<GraphOps>("ops").emplace_back(new ComputationOpHandle(
-        result->CreateOpNode(node->Op()), *node->Op(), s, p));
+    result->Get<GraphOps>("ops").emplace_back(
+        new ComputationOpHandle(result->CreateOpNode(node->Op()), s, p));
     CreateOpHandleIOs(result, node, scope_idx);
   }
 }
@@ -519,10 +522,10 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
                                                    int dst_dev_id) const {
 #ifdef PADDLE_WITH_CUDA
   result->Get<GraphOps>("ops").emplace_back(new ReduceOpHandle(
-      result->CreateOpNode(nullptr), local_scopes_, places_, nccl_ctxs_));
+      result->CreateEmptyNode("reduce"), local_scopes_, places_, nccl_ctxs_));
 #else
   result->Get<GraphOps>("ops").emplace_back(new ReduceOpHandle(
-      result->CreateOpNode(nullptr), local_scopes_, places_));
+      result->CreateEmptyNode("reduce"), local_scopes_, places_));
 #endif
   auto *op_handle = result->Get<GraphOps>("ops").back().get();
 
@@ -535,7 +538,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
     op_handle->AddInput(prev_grad.get());
   }
   auto &vars = result->Get<GraphVars>("vars")[dst_dev_id][og];
-  auto var = new VarHandle(result->CreateVarNode(og), vars.size(), dst_dev_id,
+  auto var = new VarHandle(result->CreateEmptyNode(og), vars.size(), dst_dev_id,
                            og, places_[dst_dev_id]);
   vars.emplace_back(var);
   op_handle->AddOutput(var);
@@ -548,7 +551,7 @@ void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op,
                                         const std::string &prev_op_name) const {
   for (auto &prev_op : result->Get<GraphOps>("ops")) {
     if (prev_op->Name() == prev_op_name) {
-      auto *dep_var = new DummyVarHandle(result->CreateVarNode("dummy"));
+      auto *dep_var = new DummyVarHandle(result->CreateEmptyNode("dummy"));
       prev_op->AddOutput(dep_var);
       result->Get<GraphDepVars>("dep_vars").emplace(dep_var);
       op->AddInput(dep_var);
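ConnectOp shows the control-dependency idiom that the dummy vars in the tests above exist for: a data-free variable produced by the earlier op and consumed by the later one gives the scheduler an ordering edge. A compact sketch with hypothetical types:

#include <memory>
#include <string>
#include <vector>

struct Var { std::string name; };  // hypothetical stand-ins
struct Op {
  std::vector<Var *> inputs, outputs;
  void AddInput(Var *v) { inputs.push_back(v); }
  void AddOutput(Var *v) { outputs.push_back(v); }
};

void Connect(Op *earlier, Op *later,
             std::vector<std::unique_ptr<Var>> *dep_vars) {
  dep_vars->emplace_back(new Var{"dummy"});  // cf. CreateEmptyNode("dummy")
  Var *dep = dep_vars->back().get();
  earlier->AddOutput(dep);  // producer side of the edge
  later->AddInput(dep);     // consumer side: runs strictly after `earlier`
}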
@@ -562,10 +565,10 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result,
   std::vector<std::string> input_var_names;
   std::vector<std::string> output_var_names;
   for (ir::Node *input : node->inputs) {
-    input_var_names.push_back(input->Var()->Name());
+    input_var_names.push_back(input->Name());
   }
   for (ir::Node *output : node->outputs) {
-    output_var_names.push_back(output->Var()->Name());
+    output_var_names.push_back(output->Name());
   }
 
   if (node->Op()->Type() == "split_byref" ||
@@ -606,16 +609,16 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result,
 void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result, ir::Node *node) const {
   int op_dev_id = -1;
   if (node->Op()->Type() == "send") {
-    op_dev_id = GetVarDeviceID(node->inputs[0]->Var()->Name());
+    op_dev_id = GetVarDeviceID(node->inputs[0]->Name());
     // the variable name which contains .block means it was splited by
     // split_byref op
     // so that we can balance the variable blocks to all the pserver
     // instances.
     if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce &&
-        node->inputs[0]->Var()->Name().find(".block") == std::string::npos) {
+        node->inputs[0]->Name().find(".block") == std::string::npos) {
       std::vector<std::string> input_var_names;
       for (ir::Node *n : node->inputs) {
-        input_var_names.push_back(n->Var()->Name());
+        input_var_names.push_back(n->Name());
       }
       op_dev_id = GetAppropriateDeviceID(input_var_names);
       for (auto &varname : input_var_names) {
@@ -625,7 +628,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result, ir::Node *node) const {
   } else if (node->Op()->Type() == "recv") {
     std::vector<std::string> output_var_names;
     for (ir::Node *n : node->outputs) {
-      output_var_names.push_back(n->Var()->Name());
+      output_var_names.push_back(n->Name());
     }
     op_dev_id = GetAppropriateDeviceID(output_var_names);
     for (auto &varname : output_var_names) {
