Commit ff5a7b6

Commit message: polish
Parent: a891708

9 files changed (+85 -65 lines)

paddle/fluid/framework/details/broadcast_op_handle_test.cc

Lines changed: 10 additions & 5 deletions
@@ -96,7 +96,8 @@ struct TestBroadcastOpHandle {
     }
     param_scopes_[input_scope_idx]->Var("input");

-    std::unique_ptr<ir::Node> n(new ir::Node("node0"));
+    std::unique_ptr<ir::Node> n(
+        new ir::Node("node0", ir::Node::Type::kOperation));
     if (use_gpu_) {
 #ifdef PADDLE_WITH_CUDA
       op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_,
@@ -114,15 +115,17 @@ struct TestBroadcastOpHandle {
 #endif
     }

-    std::unique_ptr<ir::Node> v(new ir::Node("node1"));
+    std::unique_ptr<ir::Node> v(
+        new ir::Node("node1", ir::Node::Type::kVariable));
     auto* in_var_handle = new VarHandle(v.get(), 1, input_scope_idx, "input",
                                         gpu_list_[input_scope_idx]);
     vars_.emplace_back(in_var_handle);
     op_handle_->AddInput(in_var_handle);

     // add dummy var

-    std::unique_ptr<ir::Node> v2(new ir::Node("node2"));
+    std::unique_ptr<ir::Node> v2(
+        new ir::Node("node2", ir::Node::Type::kVariable));
     vars_.emplace_back(new DummyVarHandle(v2.get()));
     DummyVarHandle* dummy_var_handle =
         static_cast<DummyVarHandle*>(vars_.back().get());
@@ -133,15 +136,17 @@ struct TestBroadcastOpHandle {
       if (!use_gpu_) {
         op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
       }
-      std::unique_ptr<ir::Node> v3(new ir::Node("node3"));
+      std::unique_ptr<ir::Node> v3(
+          new ir::Node("node3", ir::Node::Type::kVariable));
       VarHandle* out_var_handle =
           new VarHandle(v3.get(), 2, j, "out", gpu_list_[j]);
       vars_.emplace_back(out_var_handle);
       op_handle_->AddOutput(out_var_handle);
     }

     // add dummy var
-    std::unique_ptr<ir::Node> v4(new ir::Node("node4"));
+    std::unique_ptr<ir::Node> v4(
+        new ir::Node("node4", ir::Node::Type::kVariable));
     vars_.emplace_back(new DummyVarHandle(v4.get()));
     DummyVarHandle* out_dummy_var_handle =
         static_cast<DummyVarHandle*>(vars_.back().get());

paddle/fluid/framework/details/gather_op_handle_test.cc

Lines changed: 5 additions & 5 deletions
@@ -82,36 +82,36 @@ struct TestGatherOpHandle {
     }
     param_scopes_[input_scope_idx]->Var("out");

-    nodes.emplace_back(new ir::Node("node"));
+    nodes.emplace_back(new ir::Node("node", ir::Node::Type::kOperation));
     op_handle_.reset(
         new GatherOpHandle(nodes.back().get(), local_scopes_, gpu_list_));
     // add input
     for (size_t j = 0; j < gpu_list_.size(); ++j) {
       op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
-      nodes.emplace_back(new ir::Node("node1"));
+      nodes.emplace_back(new ir::Node("node1", ir::Node::Type::kVariable));
       auto* in_var_handle =
           new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]);
       vars_.emplace_back(in_var_handle);
       op_handle_->AddInput(in_var_handle);
     }

     // add dummy var
-    nodes.emplace_back(new ir::Node("node2"));
+    nodes.emplace_back(new ir::Node("node2", ir::Node::Type::kVariable));
     vars_.emplace_back(new DummyVarHandle(nodes.back().get()));
     DummyVarHandle* in_dummy_var_handle =
         static_cast<DummyVarHandle*>(vars_.back().get());
     in_dummy_var_handle->ClearGeneratedOp();
     op_handle_->AddInput(in_dummy_var_handle);

     // add output
-    nodes.emplace_back(new ir::Node("node3"));
+    nodes.emplace_back(new ir::Node("node3", ir::Node::Type::kVariable));
     auto* out_var_handle = new VarHandle(nodes.back().get(), 2, input_scope_idx,
                                          "out", gpu_list_[input_scope_idx]);
     vars_.emplace_back(out_var_handle);
     op_handle_->AddOutput(out_var_handle);

     // add dummy var
-    nodes.emplace_back(new ir::Node("node4"));
+    nodes.emplace_back(new ir::Node("node4", ir::Node::Type::kVariable));
     vars_.emplace_back(new DummyVarHandle(nodes.back().get()));
     DummyVarHandle* dummy_var_handle =
         static_cast<DummyVarHandle*>(vars_.back().get());

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 49 additions & 34 deletions
@@ -80,7 +80,14 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(Graph *result, ir::Node *node,
   }

   for (ir::Node *output : node->outputs) {
-    CreateOpOutput(result, op_handle, output, p, place_id);
+    ir::Node *new_node = nullptr;
+    if (output->Var()) {
+      new_node = result->CreateVarNode(output->Var());
+    } else {
+      new_node =
+          result->CreateEmptyNode(output->Name(), ir::Node::Type::kVariable);
+    }
+    CreateOpOutput(result, op_handle, new_node, p, place_id);
   }
 }

@@ -246,7 +253,7 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
       if (node->Op()->Type() == "read" && strategy_.enable_data_balance_) {
         node->Op()->SetAttr("throw_eof_exp", false);
         CreateComputationalOps(&result, node.get(), places_.size());
-        // TODO(panyx0718): builder shouldn't depend on the out logic of
+        // TODO(paddle-dev): builder shouldn't depend on the out logic of
         // a specific op.
         const auto &data_var_names = node->Op()->Output("Out");
         InsertDataBalanceOp(&result, data_var_names);

@@ -354,11 +361,13 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
                                                 const std::string &p_name,
                                                 size_t src_dev_id) const {
 #ifdef PADDLE_WITH_CUDA
-  auto *op_handle = new BroadcastOpHandle(result->CreateEmptyNode("broadcast"),
-                                          local_scopes_, places_, nccl_ctxs_);
+  auto *op_handle = new BroadcastOpHandle(
+      result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation),
+      local_scopes_, places_, nccl_ctxs_);
 #else
-  auto *op_handle = new BroadcastOpHandle(result->CreateEmptyNode("broadcast"),
-                                          local_scopes_, places_);
+  auto *op_handle = new BroadcastOpHandle(
+      result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation),
+      local_scopes_, places_);
 #endif
   result->Get<GraphOps>("ops").emplace_back(op_handle);

@@ -370,8 +379,9 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
     auto &p = places_[i];
     SetCommunicationContext(op_handle, p);
     auto &vars = result->Get<GraphVars>("vars").at(i).at(p_name);
-    auto *out_var = new VarHandle(result->CreateEmptyNode(p_name), vars.size(),
-                                  i, p_name, p);
+    auto *out_var = new VarHandle(
+        result->CreateEmptyNode(p_name, ir::Node::Type::kVariable), vars.size(),
+        i, p_name, p);
     vars.emplace_back(out_var);
     op_handle->AddOutput(out_var);
   }

@@ -389,12 +399,13 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result,
 void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
                                                 const std::string &og) const {
 #ifdef PADDLE_WITH_CUDA
-  result->Get<GraphOps>("ops").emplace_back(
-      new AllReduceOpHandle(result->CreateEmptyNode("allreduce"), local_scopes_,
-                            places_, nccl_ctxs_));
+  result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
+      result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
+      local_scopes_, places_, nccl_ctxs_));
 #else
   result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
-      result->CreateEmptyNode("allreduce"), local_scopes_, places_));
+      result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
+      local_scopes_, places_));
 #endif
   auto *op_handle = result->Get<GraphOps>("ops").back().get();

@@ -407,7 +418,8 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
     op_handle->AddInput(prev_grad.get());

     auto var =
-        new VarHandle(result->CreateEmptyNode(og), vars.size(), i, og, p);
+        new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable),
+                      vars.size(), i, og, p);
     vars.emplace_back(var);
     op_handle->AddOutput(var);
   }

@@ -416,12 +428,13 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
 void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
     Graph *result, const std::vector<std::string> &datas) const {
 #ifdef PADDLE_WITH_CUDA
-  result->Get<GraphOps>("ops").emplace_back(
-      new DataBalanceOpHandle(result->CreateEmptyNode("data_balance"),
-                              local_scopes_, places_, nccl_ctxs_));
+  result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
+      result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
+      local_scopes_, places_, nccl_ctxs_));
 #else
   result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
-      result->CreateEmptyNode("data_balance"), local_scopes_, places_));
+      result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
+      local_scopes_, places_));
 #endif
   auto *op_handle = result->Get<GraphOps>("ops").back().get();
   for (size_t i = 0; i < places_.size(); ++i) {

@@ -431,8 +444,9 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
     auto &vars = result->Get<GraphVars>("vars")[i][d_name];
     PADDLE_ENFORCE(!vars.empty());
     op_handle->AddInput(vars.back().get());
-    auto var = new VarHandle(result->CreateEmptyNode(d_name), vars.size(), i,
-                             d_name, p);
+    auto var = new VarHandle(
+        result->CreateEmptyNode(d_name, ir::Node::Type::kVariable),
+        vars.size(), i, d_name, p);
     vars.emplace_back(var);
     op_handle->AddOutput(var);
   }

@@ -487,8 +501,9 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
         platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
 #endif
     auto *op_handle = new ScaleLossGradOpHandle(
-        result->CreateEmptyNode("scale_loss_grad"), local_scopes_.size(),
-        local_scopes_[i], places_[i], communication_dev_ctx);
+        result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation),
+        local_scopes_.size(), local_scopes_[i], places_[i],
+        communication_dev_ctx);
     result->Get<GraphOps>("ops").emplace_back(op_handle);

     // FIXME: Currently ScaleLossGradOp only use device_count as scale

@@ -497,14 +512,10 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
     // loss->pending_ops_.emplace_back(op_handle);
     // op_handle->inputs_.emplace_back(loss);

-    // TODO(panyx0718): GradVarName(loss_var_name_)
-    const std::string grad_var_name = GradVarName(loss_var_name_);
-    auto &vars = result->Get<GraphVars>("vars")[i][grad_var_name];
-    size_t version = vars.size();
-    auto var = new VarHandle(result->CreateEmptyNode(grad_var_name), version, i,
-                             grad_var_name, places_[i]);
-    vars.emplace_back(var);
-    op_handle->AddOutput(var);
+    CreateOpOutput(result, op_handle,
+                   result->CreateEmptyNode(GradVarName(loss_var_name_),
+                                           ir::Node::Type::kVariable),
+                   places_[i], i);
   }
 }

@@ -525,10 +536,12 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
                                                    int dst_dev_id) const {
 #ifdef PADDLE_WITH_CUDA
   result->Get<GraphOps>("ops").emplace_back(new ReduceOpHandle(
-      result->CreateEmptyNode("reduce"), local_scopes_, places_, nccl_ctxs_));
+      result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
+      local_scopes_, places_, nccl_ctxs_));
 #else
   result->Get<GraphOps>("ops").emplace_back(new ReduceOpHandle(
-      result->CreateEmptyNode("reduce"), local_scopes_, places_));
+      result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
+      local_scopes_, places_));
 #endif
   auto *op_handle = result->Get<GraphOps>("ops").back().get();

@@ -541,8 +554,9 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
     op_handle->AddInput(prev_grad.get());
   }
   auto &vars = result->Get<GraphVars>("vars")[dst_dev_id][og];
-  auto var = new VarHandle(result->CreateEmptyNode(og), vars.size(), dst_dev_id,
-                           og, places_[dst_dev_id]);
+  auto var =
+      new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable),
+                    vars.size(), dst_dev_id, og, places_[dst_dev_id]);
   vars.emplace_back(var);
   op_handle->AddOutput(var);
   return var;

@@ -554,7 +568,8 @@ void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op,
                                         const std::string &prev_op_name) const {
   for (auto &prev_op : result->Get<GraphOps>("ops")) {
     if (prev_op->Name() == prev_op_name) {
-      auto *dep_var = new DummyVarHandle(result->CreateEmptyNode("dummy"));
+      auto *dep_var = new DummyVarHandle(
+          result->CreateEmptyNode("dummy", ir::Node::Type::kVariable));
       prev_op->AddOutput(dep_var);
       result->Get<GraphDepVars>("dep_vars").emplace(dep_var);
       op->AddInput(dep_var);
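The pattern introduced in CreateOpHandleIOs above recurs throughout this file: an output backed by a VarDesc becomes a real variable node via CreateVarNode, and everything else falls back to an explicitly typed empty node. A minimal standalone sketch of that rule (MakeOutputNode is a hypothetical helper named here only for illustration, not part of this commit):

  // Hypothetical helper illustrating the node-selection rule used above.
  ir::Node *MakeOutputNode(Graph *result, ir::Node *output) {
    if (output->Var()) {
      // The variable has a VarDesc, so create a full variable node from it.
      return result->CreateVarNode(output->Var());
    }
    // No VarDesc: the type can no longer default to kNone, so the caller
    // must state explicitly that this placeholder is a variable.
    return result->CreateEmptyNode(output->Name(), ir::Node::Type::kVariable);
  }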

paddle/fluid/framework/details/ssa_graph_builder.cc

Lines changed: 12 additions & 9 deletions
@@ -37,7 +37,8 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
           continue;
         }

-        auto *dep_var = new DummyVarHandle(graph->CreateEmptyNode("dummy"));
+        auto *dep_var = new DummyVarHandle(
+            graph->CreateEmptyNode("dummy", ir::Node::Type::kVariable));
         read_op->AddOutput(dep_var);
         write_op->AddInput(dep_var);
         graph->Get<GraphDepVars>("dep_vars").emplace(dep_var);

@@ -54,12 +55,13 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
   auto &var_holder = var_holders[node->Name()];
   VarHandle *var = nullptr;
   if (var_holder.empty()) {
-    if (node->NodeType() == ir::Node::Type::kVariable) {
+    if (node->Var()) {
       var = new VarHandle(graph->CreateVarNode(node->Var()), 0, place_offset,
                           node->Name(), place);
     } else {
-      var = new VarHandle(graph->CreateEmptyNode(node->Name()), 0, place_offset,
-                          node->Name(), place);
+      var = new VarHandle(
+          graph->CreateEmptyNode(node->Name(), ir::Node::Type::kVariable), 0,
+          place_offset, node->Name(), place);
     }
     var_holder.emplace_back(var);
   } else {

@@ -69,13 +71,13 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
 }

 void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
-                                     ir::Node *node,
+                                     ir::Node *new_node,
                                      const platform::Place &place,
                                      size_t place_offset) {
-  auto &vars = graph->Get<GraphVars>("vars")[place_offset][node->Name()];
+  auto &vars = graph->Get<GraphVars>("vars")[place_offset][new_node->Name()];
   size_t version = vars.size();
-  auto var = new VarHandle(graph->CreateVarNode(node->Var()), version,
-                           place_offset, node->Name(), place);
+  auto var =
+      new VarHandle(new_node, version, place_offset, new_node->Name(), place);
   vars.emplace_back(var);
   op_handle->AddOutput(var);
 }

@@ -85,7 +87,8 @@ void SSAGraphBuilder::AddOutputToLeafOps(Graph *graph) {
     if (!op->Outputs().empty()) {
      continue;
     }
-    auto *dummy_leaf = new DummyVarHandle(graph->CreateEmptyNode("dummy"));
+    auto *dummy_leaf = new DummyVarHandle(
+        graph->CreateEmptyNode("dummy", ir::Node::Type::kVariable));
     graph->Get<GraphDepVars>("dep_vars").emplace(dummy_leaf);
     op->AddOutput(dummy_leaf);
   }
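Note the inversion of responsibility in CreateOpOutput: it used to create the variable node itself via CreateVarNode, and now it receives a ready-made new_node from the caller. Under the new contract a call site looks roughly like this, mirroring the scale_loss_grad change above (the local name grad_node is illustrative):

  // The caller creates the node, choosing its type, then hands it over.
  ir::Node *grad_node = result->CreateEmptyNode(
      GradVarName(loss_var_name_), ir::Node::Type::kVariable);
  CreateOpOutput(result, op_handle, grad_node, places_[i], i);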

paddle/fluid/framework/details/ssa_graph_builder.h

Lines changed: 1 addition & 1 deletion
@@ -73,7 +73,7 @@ class SSAGraphBuilder : public ir::Pass {
   // Add an output variable (each_var_name, place, place_offset) to op_handle,
   // which belongs to graph
   static void CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
-                             ir::Node *node, const platform::Place &place,
+                             ir::Node *new_node, const platform::Place &place,
                              size_t place_offset);

   static void AddOutputToLeafOps(Graph *graph);

paddle/fluid/framework/details/threaded_ssa_graph_executor.cc

Lines changed: 2 additions & 2 deletions
@@ -173,7 +173,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
     auto &var_name = fetch_tensors[i];
     auto &vars = fetched_vars.at(var_name);

-    temp_nodes->emplace_back(new ir::Node("fetch"));
+    temp_nodes->emplace_back(new ir::Node("fetch", ir::Node::Type::kOperation));
     auto *op = new FetchOpHandle(temp_nodes->back().get(), fetch_data, i,
                                  &local_scopes_);
     fetch_ops->emplace_back(op);

@@ -186,7 +186,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
       op->AddInput(var);
     }

-    temp_nodes->emplace_back(new ir::Node("fetch"));
+    temp_nodes->emplace_back(new ir::Node("fetch", ir::Node::Type::kOperation));
     auto *fetch_dummy = new DummyVarHandle(temp_nodes->back().get());
     op->AddOutput(fetch_dummy);
     fetch_dependencies->emplace(fetch_dummy);

paddle/fluid/framework/ir/graph.cc

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
         // TODO(paddle-dev): Seems some assumption doesn't hold?
         LOG(ERROR) << op->Type()
                    << " input var not in all_var list: " << each_var_name;
-        var = CreateEmptyNode(each_var_name);
+        var = CreateEmptyNode(each_var_name, ir::Node::Type::kVariable);
         var_nodes[each_var_name] = var;
       }
       node->inputs.push_back(var);

paddle/fluid/framework/ir/graph.h

Lines changed: 2 additions & 2 deletions
@@ -67,8 +67,8 @@ class Graph {
   // TODO(paddle-dev): There shouldn't be kNone nodes in the ir::Graph.
   // node should either be a executable kOperation or a kVariable. kNone
   // node is a temporary solution.
-  ir::Node* CreateEmptyNode(const std::string& name) {
-    nodes.emplace_back(new ir::Node(name));
+  ir::Node* CreateEmptyNode(const std::string& name, ir::Node::Type type) {
+    nodes.emplace_back(new ir::Node(name, type));
     return nodes.back().get();
   }
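CreateEmptyNode is now a two-argument factory: every caller has to say whether the placeholder stands for an operation or a variable. For example (names taken from call sites elsewhere in this commit; op_node and var_node are illustrative locals):

  ir::Node* op_node =
      graph->CreateEmptyNode("broadcast", ir::Node::Type::kOperation);
  ir::Node* var_node =
      graph->CreateEmptyNode("dummy", ir::Node::Type::kVariable);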

paddle/fluid/framework/ir/node.h

Lines changed: 3 additions & 6 deletions
@@ -26,12 +26,9 @@ namespace ir {

 class Node {
  public:
-  enum class Type { kNone, kOperation, kVariable };
-  explicit Node(const std::string& name)
-      : name_(name),
-        var_desc_(nullptr),
-        op_desc_(nullptr),
-        type_(Type::kNone) {}
+  enum class Type { kOperation, kVariable };
+  explicit Node(const std::string& name, Type type)
+      : name_(name), var_desc_(nullptr), op_desc_(nullptr), type_(type) {}

   explicit Node(VarDesc* var_desc)
       : name_(var_desc->Name()),
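With Type::kNone gone, there is no longer any way to construct an untyped Node: the one-argument constructor is removed and the type is mandatory whenever a node is built from a bare name. A sketch of the before/after at a call site (constructor arguments taken from the call sites updated in this commit; op_node and var_node are illustrative names):

  // Old form, no longer compiles: new ir::Node("fetch");
  std::unique_ptr<ir::Node> op_node(
      new ir::Node("fetch", ir::Node::Type::kOperation));
  std::unique_ptr<ir::Node> var_node(
      new ir::Node("dummy", ir::Node::Type::kVariable));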
