Skip to content

Commit 5173a53

Browse files
committed
fix reorder issue.
1 parent 21a4542 commit 5173a53

File tree

5 files changed

+107
-47
lines changed

5 files changed

+107
-47
lines changed

doc/fluid/design/ir/draft.md

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
## Motivation
22

3-
There is a ```gap``` between the ```Program``` defined by
4-
user and the ```Executable``` that can be scheduled
3+
There is a `gap` between the `Program` defined by
4+
user and the `Executable` that can be scheduled
55
efficiently on heterogeneous hardware, either locally
66
or distributedly.
77

8-
Usually, the ```gap``` is bridged by
8+
Usually, the `gap` is bridged by
99

1010
* A serious transformations with defined order.
1111

1212
* These transformations usually involve
13-
```insert, delete, clustering, split, dependency analysis```.
13+
`insert, delete, clustering, split, dependency analysis`.
1414

1515
* Has a simple way to verify and debug each transformation.
1616

@@ -38,44 +38,44 @@ design below.
3838

3939
#### Node
4040

41-
```Node``` represents an operation that performs some computation or
41+
`Node` represents an operation that performs some computation or
4242
a variable that is input or output of operation.
4343

44-
```Node```s are connected to other ```Node```s via inputs and outputs.
44+
`Node`s are connected to other `Node`s via inputs and outputs.
4545

4646
Other properties (maybe device placement information) can be added
47-
to ```Node``` in the future if it's a
48-
common requirement of many other ```Pass```es. Otherwise, it should live
49-
in a ```Node``` wrapper class that is private to some ```Pass``` or be
50-
a local member of a ```Pass```.
47+
to `Node` in the future if it's a
48+
common requirement of many other `Pass`es. Otherwise, it should live
49+
in a `Node` wrapper class that is private to some `Pass` or be
50+
a local member of a `Pass`.
5151

5252
#### Graph
5353

54-
```Graph``` contains a list of ```Node```s, which are connected to
54+
`Graph` contains a list of `Node`s, which are connected to
5555
each other via inputs and outputs.
5656

5757
TODO: Better definitions for the graph.
5858

59-
```Graph``` can also contain ```Attribute```s. ```Attribute```s
60-
can be ``any`` thing. For example, it can be a list of "wraper"
61-
nodes. The ```wrapper``` nodes compose ```Node```s and provide
62-
helper method for execution or transformation. ```Attribute```
59+
`Graph` can also contain `Attribute`s. `Attribute`s
60+
can be `any` thing. For example, it can be a list of "wraper"
61+
nodes. The `wrapper` nodes compose `Node`s and provide
62+
helper method for execution or transformation. `Attribute`
6363
can also contain other things that describe some properties of
64-
the ```Graph``` or ```Graph``` nodes. ```Attribute``` can be passed
65-
across ```Pass```. However, it should be used with care.
64+
the `Graph` or `Graph` nodes. `Attribute` can be passed
65+
across `Pass`. However, it should be used with care.
6666

6767
#### Pass
6868

69-
```Pass``` represents a transformation of ```Graph```. Its input
70-
is a ```Graph``` and its output is also a ```Graph```. For example,
71-
a ```Pass``` can simply print out the ```Graph```. A ```Pass```
72-
can also fuse some ```Graph```'s ```Node```s.
69+
`Pass` represents a transformation of `Graph`. Its input
70+
is a `Graph` and its output is also a `Graph`. For example,
71+
a `Pass` can simply print out the `Graph`. A `Pass`
72+
can also fuse some `Graph`'s `Node`s.
7373

7474
#### Optimize
7575

76-
```Optimize``` contains a series of ```Pass``` with defined order.
77-
```Optimize``` transforms a ```Graph``` that only contains raw
78-
modeling logic to a ```Graph``` that can be run efficiently while
76+
`Optimize` contains a series of `Pass` with defined order.
77+
`Optimize` transforms a `Graph` that only contains raw
78+
modeling logic to a `Graph` that can be run efficiently while
7979
maintaining the original modeling logic.
8080

8181

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 36 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -196,38 +196,46 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
196196
std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
197197
std::vector<ir::Node *> ret = ir::TopologySortOperations(graph);
198198
size_t last_backward = 0;
199-
std::vector<ir::Node *> optimize_ops;
200-
std::vector<ir::Node *> sorted_ret;
201199
for (size_t i = 0; i < ret.size(); ++i) {
202200
if (boost::get<int>(
203201
ret[i]->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
204202
static_cast<int>(OpRole::kBackward)) {
205-
sorted_ret.push_back(ret[i]);
206-
last_backward = sorted_ret.size();
207-
} else if (boost::get<int>(ret[i]->Op()->GetAttr(
208-
OpProtoAndCheckerMaker::OpRoleAttrName())) ==
209-
static_cast<int>(OpRole::kOptimize)) {
210-
optimize_ops.push_back(ret[i]);
211-
} else {
212-
sorted_ret.push_back(ret[i]);
203+
last_backward = i;
213204
}
214205
}
215206

216-
// Verify that no operations before optimize ops depends on optimize ops.
217-
std::unordered_set<ir::Node *> optimize_set(optimize_ops.begin(),
218-
optimize_ops.end());
219-
for (size_t i = 0; i < last_backward; ++i) {
220-
for (ir::Node *in : sorted_ret[i]->inputs) {
221-
for (ir::Node *pre_n : in->inputs) {
222-
PADDLE_ENFORCE(optimize_set.find(pre_n) == optimize_set.end(),
223-
"optimize operations cannot be depended by forward "
224-
"or backward node %s -> %s",
225-
pre_n->Name(), sorted_ret[i]->Name());
207+
std::vector<ir::Node *> optimize_ops;
208+
std::vector<ir::Node *> sorted_ret;
209+
for (size_t i = 0; i < ret.size(); ++i) {
210+
if (i < last_backward) {
211+
if (boost::get<int>(ret[i]->Op()->GetAttr(
212+
OpProtoAndCheckerMaker::OpRoleAttrName())) ==
213+
static_cast<int>(OpRole::kOptimize)) {
214+
optimize_ops.push_back(ret[i]);
215+
} else {
216+
sorted_ret.push_back(ret[i]);
217+
}
218+
} else if (i == last_backward) {
219+
sorted_ret.push_back(ret[i]);
220+
// Verify that no operations before optimize ops depends on optimize ops.
221+
std::unordered_set<ir::Node *> optimize_set(optimize_ops.begin(),
222+
optimize_ops.end());
223+
for (ir::Node *n : sorted_ret) {
224+
for (ir::Node *in : n->inputs) {
225+
for (ir::Node *pre_n : in->inputs) {
226+
PADDLE_ENFORCE(optimize_set.find(pre_n) == optimize_set.end(),
227+
"optimize operations cannot be depended by forward "
228+
"or backward node %s -> %s",
229+
pre_n->Name(), n->Name());
230+
}
231+
}
226232
}
233+
sorted_ret.insert(sorted_ret.end(), optimize_ops.begin(),
234+
optimize_ops.end());
235+
} else {
236+
sorted_ret.push_back(ret[i]);
227237
}
228238
}
229-
sorted_ret.insert(sorted_ret.begin() + last_backward, optimize_ops.begin(),
230-
optimize_ops.end());
231239
return sorted_ret;
232240
}
233241

@@ -239,7 +247,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
239247
ir::Graph &result = *graph;
240248

241249
for (auto &node : nodes) {
242-
if (node->NodeType() == ir::Node::Type::kVariable) {
250+
if (node->NodeType() == ir::Node::Type::kVariable && node->Var()) {
243251
all_vars_.emplace(node->Name(), node->Var());
244252
}
245253
}
@@ -361,6 +369,11 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
361369
}
362370
}
363371
}
372+
/*
373+
Dependency graph has been constructed. However, there are still data
374+
hazards need to be handled.
375+
*/
376+
PolishGraphToSupportDataHazards(&result);
364377

365378
/*
366379
* Only variables should be the leaves of graph.

paddle/fluid/framework/details/ssa_graph_builder.cc

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,46 @@
1717
namespace paddle {
1818
namespace framework {
1919
namespace details {
20+
void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
21+
for (auto &var_map : graph->Get<GraphVars>("vars")) {
22+
for (auto &name_pair : var_map) {
23+
if (name_pair.second.size() <= 1) {
24+
continue;
25+
}
26+
auto it_new = name_pair.second.rbegin();
27+
auto it_old = name_pair.second.rbegin();
28+
++it_old;
29+
for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) {
30+
OpHandleBase *write_op = (*it_new)->GeneratedOp();
31+
const auto &read_ops = (*it_old)->PendingOps();
32+
33+
for (auto *read_op : read_ops) {
34+
// Manually add a dependency var from read_op to write_op;
35+
if (read_op == write_op) {
36+
// Read Write is the same op.
37+
continue;
38+
}
39+
bool has_dep = false;
40+
for (auto *r_out : read_op->Outputs()) {
41+
for (auto *w_in : write_op->Inputs()) {
42+
if (r_out->Node() == w_in->Node()) {
43+
has_dep = true;
44+
break;
45+
}
46+
}
47+
}
48+
if (has_dep) continue;
49+
50+
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
51+
read_op->AddOutput(dep_var);
52+
write_op->AddInput(dep_var);
53+
graph->Get<GraphDepVars>("dep_vars").emplace(dep_var);
54+
}
55+
}
56+
}
57+
}
58+
}
59+
2060
VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
2161
ir::Graph *graph, ir::Node *node, const platform::Place &place,
2262
size_t place_offset) {

paddle/fluid/framework/details/ssa_graph_builder.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,12 @@ class SSAGraphBuilder : public ir::Pass {
5757
DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
5858

5959
protected:
60+
/*
61+
Dependency graph has been constructed. However, there are still data
62+
hazards need to be handled.
63+
*/
64+
static void PolishGraphToSupportDataHazards(ir::Graph *graph);
65+
6066
static VarHandle *CreateOrGetLatestVarHandle(ir::Graph *graph, ir::Node *node,
6167
const platform::Place &place,
6268
size_t place_offset);

paddle/fluid/framework/ir/graph.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
9898
}
9999
}
100100
if (has_dep) continue;
101+
101102
ir::Node *dep_var = CreateControlDepVar();
102103
read_op->outputs.push_back(dep_var);
103104
dep_var->inputs.push_back(read_op);

0 commit comments

Comments
 (0)