Skip to content

Commit 93355cc

Browse files
committed
fix control deps
1 parent f6d99d1 commit 93355cc

14 files changed

+155
-61
lines changed

paddle/fluid/framework/details/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
cc_library(var_handle SRCS var_handle.cc DEPS place framework_proto)
1+
cc_library(var_handle SRCS var_handle.cc DEPS place framework_proto node)
22
cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context lod_tensor)
33
cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
44
cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -94,12 +94,11 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
9494
}
9595

9696
std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
97-
const std::vector<std::unique_ptr<ir::Node>> &nodes) const {
97+
const std::vector<ir::Node *> &nodes) const {
9898
std::vector<std::string> send_vars;
9999
// since parameters are all in block 0,
100100
// it's enough to only scan send ops in block 0
101101
for (auto &node : nodes) {
102-
if (node->NodeType() != ir::Node::Type::kOperation) continue;
103102
OpDesc *op = node->Op();
104103
// TODO(Yancey1989): use a graceful method to find send op,
105104
// instead of the hard-coded string
@@ -114,10 +113,9 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
114113
}
115114

116115
std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
117-
const std::vector<std::unique_ptr<ir::Node>> &nodes) const {
116+
const std::vector<ir::Node *> &nodes) const {
118117
std::vector<std::string> recv_vars;
119118
for (auto &node : nodes) {
120-
if (node->NodeType() != ir::Node::Type::kOperation) continue;
121119
OpDesc *op = node->Op();
122120
// TODO(Yancey1989): use a graceful method to find recv op,
123121
// instead of the hard-coded string
@@ -214,25 +212,36 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
214212
}
215213
}
216214

215+
// Verify that no operations before optimize ops depend on optimize ops.
216+
std::unordered_set<ir::Node *> optimize_set(optimize_ops.begin(),
217+
optimize_ops.end());
218+
for (size_t i = 0; i < last_backward; ++i) {
219+
for (ir::Node *in : sorted_ret[i]->inputs) {
220+
for (ir::Node *pre_n : in->inputs) {
221+
PADDLE_ENFORCE(optimize_set.find(pre_n) == optimize_set.end(),
222+
"optimize operations cannot be depended by forward "
223+
"or backward node %s -> %s",
224+
pre_n->Name(), sorted_ret[i]->Name());
225+
}
226+
}
227+
}
217228
sorted_ret.insert(sorted_ret.begin() + last_backward, optimize_ops.begin(),
218229
optimize_ops.end());
219230
return sorted_ret;
220231
}
221232

222233
std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
223234
std::unique_ptr<ir::Graph> graph) const {
224-
// Rebuild the graph structure.
235+
// Give the topology sort order and rebuild the graph structure.
225236
std::vector<ir::Node *> sorted_ops = SortOpsAndDelayOptimizeOp(*graph);
226-
auto nodes = std::move(graph->nodes);
227-
graph->nodes.clear();
237+
auto nodes = graph->ReleaseNodes();
238+
ir::Graph &result = *graph;
228239

229240
for (auto &node : nodes) {
230241
if (node->NodeType() == ir::Node::Type::kVariable) {
231242
all_vars_.emplace(node->Name(), node->Var());
232243
}
233244
}
234-
235-
ir::Graph &result = *graph;
236245
std::unordered_set<std::string> og_has_been_broadcast;
237246

238247
// We cannot invoke resize. It is a bug of GCC 4.8
@@ -242,8 +251,8 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
242251

243252
// find send/recv vars so that we can place the distributed training
244253
// related op in the place 0
245-
auto send_vars = FindDistTrainSendVars(nodes);
246-
auto recv_vars = FindDistTrainRecvVars(nodes);
254+
auto send_vars = FindDistTrainSendVars(sorted_ops);
255+
auto recv_vars = FindDistTrainRecvVars(sorted_ops);
247256

248257
std::vector<std::unordered_set<std::string>> bcast_var_name_set;
249258
bcast_var_name_set.resize(places_.size());
@@ -589,8 +598,7 @@ void MultiDevSSAGraphBuilder::ConnectOp(ir::Graph *result, OpHandleBase *op,
589598
const std::string &prev_op_name) const {
590599
for (auto &prev_op : result->Get<GraphOps>("ops")) {
591600
if (prev_op->Name() == prev_op_name) {
592-
auto *dep_var = new DummyVarHandle(
593-
result->CreateEmptyNode("dummy", ir::Node::Type::kVariable));
601+
auto *dep_var = new DummyVarHandle(result->CreateControlDepVar());
594602
prev_op->AddOutput(dep_var);
595603
result->Get<GraphDepVars>("dep_vars").emplace(dep_var);
596604
op->AddInput(dep_var);

paddle/fluid/framework/details/multi_devices_graph_builder.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,10 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
7676
const std::vector<std::string> &recv_vars) const;
7777

7878
std::vector<std::string> FindDistTrainSendVars(
79-
const std::vector<std::unique_ptr<ir::Node>> &nodes) const;
79+
const std::vector<ir::Node *> &nodes) const;
8080

8181
std::vector<std::string> FindDistTrainRecvVars(
82-
const std::vector<std::unique_ptr<ir::Node>> &nodes) const;
82+
const std::vector<ir::Node *> &nodes) const;
8383

8484
void ConnectOp(ir::Graph *result, OpHandleBase *op,
8585
const std::string &prev_op_name) const;

paddle/fluid/framework/details/rpc_op_handle.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ void RPCOpHandle::RunImpl() {
3333
for (auto *in : inputs_) {
3434
auto &p = static_cast<VarHandle *>(in)->place_;
3535
// FIXME(Yancey1989): need a better solution instead of use DebugString()
36-
if (in->DebugString() == "dummy") { // HACK
36+
if (in->Node()->Name().find(ir::Node::kControlDepVarName) !=
37+
std::string::npos) { // HACK
3738
continue;
3839
}
3940
if (in->GeneratedOp()) {

paddle/fluid/framework/details/ssa_graph_builder.cc

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,36 @@
1717
namespace paddle {
1818
namespace framework {
1919
namespace details {
20+
void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
21+
for (auto &var_map : graph->Get<GraphVars>("vars")) {
22+
for (auto &name_pair : var_map) {
23+
if (name_pair.second.size() <= 1) {
24+
continue;
25+
}
26+
auto it_new = name_pair.second.rbegin();
27+
auto it_old = name_pair.second.rbegin();
28+
++it_old;
29+
for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) {
30+
OpHandleBase *write_op = (*it_new)->GeneratedOp();
31+
const auto &read_ops = (*it_old)->PendingOps();
32+
33+
for (auto *read_op : read_ops) {
34+
// Manually add a dependency var from read_op to write_op;
35+
if (read_op == write_op) {
36+
// Read Write is the same op.
37+
continue;
38+
}
39+
40+
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
41+
read_op->AddOutput(dep_var);
42+
write_op->AddInput(dep_var);
43+
graph->Get<GraphDepVars>("dep_vars").emplace(dep_var);
44+
}
45+
}
46+
}
47+
}
48+
}
49+
2050
VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
2151
ir::Graph *graph, ir::Node *node, const platform::Place &place,
2252
size_t place_offset) {
@@ -56,8 +86,7 @@ void SSAGraphBuilder::AddOutputToLeafOps(ir::Graph *graph) {
5686
if (!op->Outputs().empty()) {
5787
continue;
5888
}
59-
auto *dummy_leaf = new DummyVarHandle(
60-
graph->CreateEmptyNode("dummy", ir::Node::Type::kVariable));
89+
auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar());
6190
graph->Get<GraphDepVars>("dep_vars").emplace(dummy_leaf);
6291
op->AddOutput(dummy_leaf);
6392
}

paddle/fluid/framework/details/ssa_graph_builder.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,15 @@ class SSAGraphBuilder : public ir::Pass {
5757
DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
5858

5959
protected:
60+
/**
61+
* We only handle write after read(WAR), since it should not have a write
62+
* after write in the program. If there are write-after-write operators, we need
63+
* to prune them.
64+
*
65+
* https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
66+
*/
67+
static void PolishGraphToSupportDataHazards(ir::Graph *graph);
68+
6069
static VarHandle *CreateOrGetLatestVarHandle(ir::Graph *graph, ir::Node *node,
6170
const platform::Place &place,
6271
size_t place_offset);

paddle/fluid/framework/details/var_handle.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ std::string VarHandle::DebugString() const {
2626
return ss.str();
2727
}
2828

29-
std::string DummyVarHandle::DebugString() const { return "dummy"; }
29+
std::string DummyVarHandle::DebugString() const { return node_->Name(); }
3030
} // namespace details
3131
} // namespace framework
3232
} // namespace paddle

paddle/fluid/framework/ir/graph.cc

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
3434
std::map<std::string, std::vector<ir::Node *>> var_nodes;
3535
for (auto *op : program.Block(0).AllOps()) {
3636
ir::Node *node = CreateOpNode(op);
37-
37+
// For input args, reuse the same var name if it was created before.
38+
// Otherwise, create a new one.
3839
for (auto &each_var_name : op->InputArgumentNames()) {
3940
ir::Node *var = nullptr;
4041
if (var_nodes.find(each_var_name) != var_nodes.end()) {
@@ -43,16 +44,16 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
4344
var = CreateVarNode(all_vars.at(each_var_name));
4445
var_nodes[each_var_name].push_back(var);
4546
} else {
46-
// TODO(paddle-dev): Seems some assumption doesn't hold?
47-
VLOG(3) << op->Type()
48-
<< " input var not in all_var list: " << each_var_name;
47+
// Operation input var can be optional (dispensable). Which means
48+
// the operation doesn't really need the var at runtime. In this
49+
// case, the non-existent var is ready at the beginning.
4950
var = CreateEmptyNode(each_var_name, ir::Node::Type::kVariable);
5051
var_nodes[each_var_name].push_back(var);
5152
}
5253
node->inputs.push_back(var);
5354
var->outputs.push_back(node);
5455
}
55-
56+
// For output args, always create a new var.
5657
for (auto &each_var_name : op->OutputArgumentNames()) {
5758
ir::Node *var = CreateVarNode(all_vars.at(each_var_name));
5859
var_nodes[each_var_name].push_back(var);
@@ -67,6 +68,7 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
6768
*
6869
* https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
6970
*/
71+
7072
for (auto &var : var_nodes) {
7173
auto &versions = var.second;
7274
if (versions.size() <= 1) continue;
@@ -85,8 +87,18 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
8587
// Read Write is the same op.
8688
continue;
8789
}
88-
ir::Node *dep_var = CreateEmptyNode(ir::Node::kControlDepVarName,
89-
ir::Node::Type::kVariable);
90+
// 2 ops might have been connected via other vars.
91+
bool has_dep = false;
92+
for (ir::Node *r_out : read_op->outputs) {
93+
for (ir::Node *w_in : write_op->inputs) {
94+
if (r_out == w_in) {
95+
has_dep = true;
96+
break;
97+
}
98+
}
99+
}
100+
if (has_dep) continue;
101+
ir::Node *dep_var = CreateControlDepVar();
90102
read_op->outputs.push_back(dep_var);
91103
dep_var->inputs.push_back(read_op);
92104
write_op->inputs.push_back(dep_var);

paddle/fluid/framework/ir/graph.h

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ limitations under the License. */
2727
namespace paddle {
2828
namespace framework {
2929
namespace ir {
30+
3031
class Graph {
3132
public:
3233
explicit Graph(const ProgramDesc &program);
@@ -54,28 +55,58 @@ class Graph {
5455
};
5556
}
5657

58+
const std::unordered_set<ir::Node *> &Nodes() const { return node_set_; }
59+
5760
ir::Node *CreateVarNode(VarDesc *var_desc) {
58-
nodes.emplace_back(new ir::Node(var_desc));
59-
return nodes.back().get();
61+
return AddNode(new ir::Node(var_desc));
6062
}
6163

6264
ir::Node *CreateOpNode(OpDesc *op_desc) {
63-
nodes.emplace_back(new ir::Node(op_desc));
64-
return nodes.back().get();
65+
return AddNode(new ir::Node(op_desc));
66+
}
67+
68+
ir::Node *CreateControlDepVar() {
69+
// TODO(panyx0718): control var name should be unique.
70+
const std::string name = string::Sprintf(
71+
"%s@%llu", ir::Node::kControlDepVarName, node_set_.size());
72+
return AddNode(new ir::Node(name, ir::Node::Type::kVariable));
6573
}
6674

6775
ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type) {
68-
nodes.emplace_back(new ir::Node(name, type));
69-
return nodes.back().get();
76+
return AddNode(new ir::Node(name, type));
7077
}
7178

72-
std::vector<std::unique_ptr<ir::Node>> nodes;
79+
std::vector<std::unique_ptr<ir::Node>> ReleaseNodes() {
80+
std::vector<std::unique_ptr<ir::Node>> ret;
81+
for (auto &n : nodes_) {
82+
ret.emplace_back(n.second.release());
83+
}
84+
nodes_.clear();
85+
node_set_.clear();
86+
return ret;
87+
}
7388

7489
private:
90+
// This method takes ownership of `node`.
91+
ir::Node *AddNode(ir::Node *node) {
92+
PADDLE_ENFORCE(node_set_.find(node) == node_set_.end());
93+
nodes_[node].reset(node);
94+
node_set_.insert(node);
95+
return node;
96+
}
97+
98+
void RemoveNode(ir::Node *node) {
99+
PADDLE_ENFORCE(node_set_.find(node) != node_set_.end());
100+
node_set_.erase(node);
101+
nodes_.erase(node);
102+
}
103+
75104
// NOTE: program_ shouldn't be exposed to user.
76105
const ProgramDesc &program_;
77106
std::map<std::string, boost::any> attrs_;
78107
std::map<std::string, std::function<void(void)>> attr_dels_;
108+
std::map<ir::Node *, std::unique_ptr<ir::Node>> nodes_;
109+
std::unordered_set<ir::Node *> node_set_;
79110
};
80111
} // namespace ir
81112
} // namespace framework

paddle/fluid/framework/ir/graph_helper.cc

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,8 @@ void SortHelper(
3333
}
3434
}
3535

36-
LOG(ERROR) << "topology sort insert: " << node->Name()
37-
<< reinterpret_cast<void *>(node) << " input "
38-
<< node->inputs.size();
36+
VLOG(3) << "topology sort insert: " << node->Name()
37+
<< reinterpret_cast<void *>(node) << " input " << node->inputs.size();
3938
ret->push_back(node);
4039
}
4140

@@ -93,18 +92,18 @@ std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
9392
const Graph &graph) {
9493
std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list;
9594

96-
for (auto &n : graph.nodes) {
95+
for (auto &n : graph.Nodes()) {
9796
if (n->NodeType() != ir::Node::Type::kOperation) continue;
98-
if (adj_list.find(n.get()) == adj_list.end()) {
99-
adj_list[n.get()] = std::unordered_set<ir::Node *>();
97+
if (adj_list.find(n) == adj_list.end()) {
98+
adj_list[n] = std::unordered_set<ir::Node *>();
10099
}
101100
for (auto &var : n->inputs) {
102101
for (auto &adj_n : var->inputs) {
103102
PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation);
104-
adj_list[n.get()].insert(adj_n);
105-
LOG(ERROR) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
106-
<< " -> " << n->Name() << reinterpret_cast<void *>(n.get())
107-
<< " via " << var->Name() << reinterpret_cast<void *>(var);
103+
adj_list[n].insert(adj_n);
104+
VLOG(3) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
105+
<< " -> " << n->Name() << reinterpret_cast<void *>(n)
106+
<< " via " << var->Name() << reinterpret_cast<void *>(var);
108107
}
109108
}
110109
}

0 commit comments

Comments
 (0)