Skip to content

Commit 21a4542

Browse files
committed
polish and test
1 parent 2782e71 commit 21a4542

File tree

11 files changed

+166
-48
lines changed

11 files changed

+166
-48
lines changed

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
170170
const std::vector<std::string> &var_names) const {
171171
int64_t numel_sum = 0;
172172
for (auto var_name : var_names) {
173+
if (all_vars_.find(var_name) == all_vars_.end()) continue;
173174
auto var_desc = all_vars_.at(var_name);
174175
PADDLE_ENFORCE_NOT_NULL(var_desc);
175176
auto dim = framework::make_ddim(var_desc->GetShape());
@@ -271,6 +272,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
271272
// user can customize loss@grad if not use_default_grad_scale_
272273
if (strategy_.gradient_scale_ !=
273274
BuildStrategy::GradientScaleStrategy::kCustomized) {
275+
// TODO(paddle-dev): Why is there no input for this op_handle?
274276
CreateScaleLossGradOp(&result);
275277
}
276278
// This assumes the backward generating code will ensure IsScaleLossOp
@@ -288,6 +290,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
288290
} else {
289291
// This op runs on all devices, and its output may have parameter's
290292
// gradients.
293+
// TODO(paddle-dev): Why is so special about "read" op?
291294
if (node->Op()->Type() == "read" && strategy_.enable_data_balance_) {
292295
node->Op()->SetAttr("throw_eof_exp", false);
293296
CreateComputationalOps(&result, node, places_.size());
@@ -363,6 +366,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
363366
* Only variables should be the leaves of graph.
364367
*/
365368
AddOutputToLeafOps(&result);
369+
PADDLE_ENFORCE(!ir::HasCircle(result));
366370
return graph;
367371
}
368372

@@ -620,6 +624,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
620624

621625
if (node->Op()->Type() == "split_byref" ||
622626
node->Op()->Type() == "split_selected_rows") {
627+
// TODO(paddle-dev): getting the first var is not safe.
623628
op_dev_id = GetVarDeviceID(input_var_names[0]);
624629
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
625630
op_dev_id = GetAppropriateDeviceID(input_var_names);
@@ -657,7 +662,10 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
657662
ir::Node *node) const {
658663
int op_dev_id = -1;
659664
if (node->Op()->Type() == "send") {
665+
// TODO(paddle-dev): getting the first var is not safe.
660666
op_dev_id = GetVarDeviceID(node->inputs[0]->Name());
667+
PADDLE_ENFORCE(!ir::IsControlDepVar(*node->inputs[0]),
668+
"This hack no longer holds, please fix.");
661669
// the variable name which contains .block means it was splited by
662670
// split_byref op
663671
// so that we can balance the variable blocks to all the pserver

paddle/fluid/framework/details/rpc_op_handle.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
// limitations under the License.
1414

1515
#include "paddle/fluid/framework/details/rpc_op_handle.h"
16+
#include "paddle/fluid/framework/ir/graph.h"
1617

1718
namespace paddle {
1819
namespace framework {
@@ -33,8 +34,7 @@ void RPCOpHandle::RunImpl() {
3334
for (auto *in : inputs_) {
3435
auto &p = static_cast<VarHandle *>(in)->place_;
3536
// FIXME(Yancey1989): need a better solution instead of use DebugString()
36-
if (in->Node()->Name().find(ir::Node::kControlDepVarName) !=
37-
std::string::npos) { // HACK
37+
if (ir::IsControlDepVar(*in->Node())) { // HACK
3838
continue;
3939
}
4040
if (in->GeneratedOp()) {

paddle/fluid/framework/details/ssa_graph_builder.cc

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -17,36 +17,6 @@
1717
namespace paddle {
1818
namespace framework {
1919
namespace details {
20-
void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
21-
for (auto &var_map : graph->Get<GraphVars>("vars")) {
22-
for (auto &name_pair : var_map) {
23-
if (name_pair.second.size() <= 1) {
24-
continue;
25-
}
26-
auto it_new = name_pair.second.rbegin();
27-
auto it_old = name_pair.second.rbegin();
28-
++it_old;
29-
for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) {
30-
OpHandleBase *write_op = (*it_new)->GeneratedOp();
31-
const auto &read_ops = (*it_old)->PendingOps();
32-
33-
for (auto *read_op : read_ops) {
34-
// Manually add a dependency var from read_op to write_op;
35-
if (read_op == write_op) {
36-
// Read Write is the same op.
37-
continue;
38-
}
39-
40-
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
41-
read_op->AddOutput(dep_var);
42-
write_op->AddInput(dep_var);
43-
graph->Get<GraphDepVars>("dep_vars").emplace(dep_var);
44-
}
45-
}
46-
}
47-
}
48-
}
49-
5020
VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
5121
ir::Graph *graph, ir::Node *node, const platform::Place &place,
5222
size_t place_offset) {

paddle/fluid/framework/details/ssa_graph_builder.h

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -57,15 +57,6 @@ class SSAGraphBuilder : public ir::Pass {
5757
DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
5858

5959
protected:
60-
/**
61-
* We only handle write after read(WAR), since it should not have a write
62-
* after write in program. If there are write after write operators, we need
63-
* prune them.
64-
*
65-
* https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
66-
*/
67-
static void PolishGraphToSupportDataHazards(ir::Graph *graph);
68-
6960
static VarHandle *CreateOrGetLatestVarHandle(ir::Graph *graph, ir::Node *node,
7061
const platform::Place &place,
7162
size_t place_offset);

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@ cc_library(node SRCS node.cc DEPS proto_desc)
22
cc_library(graph SRCS graph.cc DEPS node)
33
cc_library(graph_helper SRCS graph_helper.cc DEPS graph)
44
cc_library(pass SRCS pass.cc DEPS graph node)
5-
cc_test(graph_test SRCS graph_test.cc DEPS graph proto_desc op_registry)
5+
cc_test(graph_test SRCS graph_test.cc DEPS graph op_registry)
6+
cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph_helper op_registry)

paddle/fluid/framework/ir/graph.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,10 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
107107
}
108108
}
109109
}
110+
111+
bool IsControlDepVar(const ir::Node &var) {
112+
return var.Name().find(ir::Node::kControlDepVarName) != std::string::npos;
113+
}
110114
} // namespace ir
111115
} // namespace framework
112116
} // namespace paddle

paddle/fluid/framework/ir/graph.h

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,25 +57,34 @@ class Graph {
5757

5858
const std::unordered_set<ir::Node *> &Nodes() const { return node_set_; }
5959

60+
// Create a normal variable with non-null VarDesc.
6061
ir::Node *CreateVarNode(VarDesc *var_desc) {
6162
return AddNode(new ir::Node(var_desc));
6263
}
6364

65+
// Create a normal runnable operator with OpDesc.
6466
ir::Node *CreateOpNode(OpDesc *op_desc) {
6567
return AddNode(new ir::Node(op_desc));
6668
}
6769

70+
// Create a control dependency var that connects 2 operations. The
71+
// var doesn't hold any data. Other than that, it's no different from
72+
// other var, considering dependency analysis.
6873
ir::Node *CreateControlDepVar() {
69-
// TODO(panyx0718): control var name should be unique.
74+
// TODO(panyx0718): control var name should be really unique.
7075
const std::string name = string::Sprintf(
7176
"%s@%llu", ir::Node::kControlDepVarName, node_set_.size());
7277
return AddNode(new ir::Node(name, ir::Node::Type::kVariable));
7378
}
7479

80+
// A more free style way of creating a graph node. Mostly use for test
81+
// or "copy" from another node. Avoid using it if possible.
7582
ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type) {
7683
return AddNode(new ir::Node(name, type));
7784
}
7885

86+
// Clear all node information of the graph and return the ownership of the
87+
// nodes.
7988
std::vector<std::unique_ptr<ir::Node>> ReleaseNodes() {
8089
std::vector<std::unique_ptr<ir::Node>> ret;
8190
for (auto &n : nodes_) {
@@ -108,6 +117,8 @@ class Graph {
108117
std::map<ir::Node *, std::unique_ptr<ir::Node>> nodes_;
109118
std::unordered_set<ir::Node *> node_set_;
110119
};
120+
121+
bool IsControlDepVar(const ir::Node &var);
111122
} // namespace ir
112123
} // namespace framework
113124
} // namespace paddle

paddle/fluid/framework/ir/graph_helper.cc

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,9 @@ bool HasCircleHelper(
5959
in_trace->erase(node);
6060
return false;
6161
}
62-
} // namespace
63-
64-
bool HasCircle(const Graph &graph) {
65-
std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list =
66-
BuildOperationAdjList(graph);
6762

63+
bool HasCircleInternal(
64+
const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list) {
6865
std::unordered_set<ir::Node *> visited;
6966
std::unordered_set<ir::Node *> in_trace;
7067
for (auto &adj : adj_list) {
@@ -74,10 +71,16 @@ bool HasCircle(const Graph &graph) {
7471
}
7572
return false;
7673
}
74+
} // namespace
75+
76+
bool HasCircle(const Graph &graph) {
77+
return HasCircleInternal(BuildOperationAdjList(graph));
78+
}
7779

7880
std::vector<ir::Node *> TopologySortOperations(const Graph &graph) {
7981
std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list =
8082
BuildOperationAdjList(graph);
83+
PADDLE_ENFORCE(!HasCircleInternal(adj_list));
8184
std::unordered_set<ir::Node *> visited;
8285
std::vector<ir::Node *> ret;
8386
for (auto adj : adj_list) {

paddle/fluid/framework/ir/graph_helper.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,14 @@ limitations under the License. */
2424
namespace paddle {
2525
namespace framework {
2626
namespace ir {
27+
// Test if the graph contains circle.
2728
bool HasCircle(const Graph &graph);
2829

30+
// Topology Sort the operations in the graph from inputs to outputs.
31+
// `graph` cannot contain circle.
2932
std::vector<ir::Node *> TopologySortOperations(const Graph &graph);
3033

34+
// Build an adjacency list of operations for the `graph`.
3135
std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
3236
const Graph &graph);
3337

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/framework/ir/graph.h"
16+
#include <string>
17+
#include "gtest/gtest.h"
18+
#include "paddle/fluid/framework/ir/graph_helper.h"
19+
#include "paddle/fluid/framework/program_desc.h"
20+
21+
namespace paddle {
22+
namespace framework {
23+
namespace ir {
24+
25+
void BuildCircleGraph(Graph* g) {
26+
ir::Node* o1 = g->CreateEmptyNode("op1", Node::Type::kOperation);
27+
ir::Node* v1 = g->CreateEmptyNode("var1", Node::Type::kVariable);
28+
29+
o1->outputs.push_back(v1);
30+
o1->inputs.push_back(v1);
31+
v1->inputs.push_back(o1);
32+
v1->outputs.push_back(o1);
33+
}
34+
35+
void BuildCircleGraph2(Graph* g) {
36+
ir::Node* o1 = g->CreateEmptyNode("op1", Node::Type::kOperation);
37+
ir::Node* o2 = g->CreateEmptyNode("op2", Node::Type::kOperation);
38+
ir::Node* v1 = g->CreateEmptyNode("var1", Node::Type::kVariable);
39+
ir::Node* v2 = g->CreateEmptyNode("var2", Node::Type::kVariable);
40+
41+
o1->outputs.push_back(v1);
42+
o2->inputs.push_back(v1);
43+
v1->inputs.push_back(o1);
44+
v1->outputs.push_back(o2);
45+
46+
o2->outputs.push_back(v2);
47+
o1->inputs.push_back(v2);
48+
v2->inputs.push_back(o2);
49+
v2->outputs.push_back(o1);
50+
}
51+
52+
void BuildNoCircleGraph(Graph* g) {
53+
ir::Node* o1 = g->CreateEmptyNode("op1", Node::Type::kOperation);
54+
ir::Node* o2 = g->CreateEmptyNode("op2", Node::Type::kOperation);
55+
ir::Node* o3 = g->CreateEmptyNode("op3", Node::Type::kOperation);
56+
ir::Node* o4 = g->CreateEmptyNode("op4", Node::Type::kOperation);
57+
ir::Node* o5 = g->CreateEmptyNode("op5", Node::Type::kOperation);
58+
ir::Node* v1 = g->CreateEmptyNode("var1", Node::Type::kVariable);
59+
ir::Node* v2 = g->CreateEmptyNode("var2", Node::Type::kVariable);
60+
ir::Node* v3 = g->CreateEmptyNode("var3", Node::Type::kVariable);
61+
ir::Node* v4 = g->CreateEmptyNode("var4", Node::Type::kVariable);
62+
63+
// o1->v1->o2
64+
o1->outputs.push_back(v1);
65+
o2->inputs.push_back(v1);
66+
v1->inputs.push_back(o1);
67+
v1->outputs.push_back(o2);
68+
// o2->v2->o3
69+
// o2->v2->o4
70+
o2->outputs.push_back(v2);
71+
o3->inputs.push_back(v2);
72+
o4->inputs.push_back(v2);
73+
v2->inputs.push_back(o2);
74+
v2->outputs.push_back(o3);
75+
v2->outputs.push_back(o4);
76+
// o2->v3->o5
77+
o2->outputs.push_back(v3);
78+
o5->inputs.push_back(v3);
79+
v3->inputs.push_back(o2);
80+
v3->outputs.push_back(o5);
81+
// o3-v4->o5
82+
o3->outputs.push_back(v4);
83+
o5->inputs.push_back(v4);
84+
v4->inputs.push_back(o3);
85+
v4->outputs.push_back(o5);
86+
}
87+
88+
TEST(GraphHelperTest, Basic) {
89+
ProgramDesc prog;
90+
91+
Graph g(prog);
92+
BuildCircleGraph(&g);
93+
ASSERT_TRUE(HasCircle(g));
94+
95+
Graph g2(prog);
96+
BuildCircleGraph2(&g2);
97+
ASSERT_TRUE(HasCircle(g2));
98+
99+
auto adj_list = BuildOperationAdjList(g2);
100+
for (auto& adj : adj_list) {
101+
auto& adj_set = adj.second;
102+
if (adj.first->Name() == "op1") {
103+
ASSERT_EQ((*adj_set.begin())->Name(), "op2");
104+
} else if (adj.first->Name() == "op2") {
105+
ASSERT_EQ((*adj_set.begin())->Name(), "op1");
106+
} else {
107+
ASSERT_TRUE(false);
108+
}
109+
}
110+
111+
Graph g3(prog);
112+
BuildNoCircleGraph(&g3);
113+
ASSERT_FALSE(HasCircle(g3));
114+
auto sorted = TopologySortOperations(g3);
115+
std::map<std::string, size_t> node_map;
116+
for (size_t i = 0; i < sorted.size(); ++i) {
117+
node_map[sorted[i]->Name()] = i;
118+
}
119+
ASSERT_EQ(node_map.at("op1"), 0);
120+
ASSERT_EQ(node_map.at("op2"), 1);
121+
ASSERT_TRUE(node_map.at("op3") < node_map.at("op5"));
122+
}
123+
} // namespace ir
124+
} // namespace framework
125+
} // namespace paddle

0 commit comments

Comments
 (0)