Skip to content

Commit f6d99d1

Browse files
committed
polish
1 parent c3f6e0e commit f6d99d1

File tree

4 files changed

+7
-220
lines changed

4 files changed

+7
-220
lines changed

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -216,21 +216,6 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
216216

217217
sorted_ret.insert(sorted_ret.begin() + last_backward, optimize_ops.begin(),
218218
optimize_ops.end());
219-
220-
for (ir::Node *n : sorted_ret) {
221-
n->inputs.erase(std::remove_if(n->inputs.begin(), n->inputs.end(),
222-
[n](ir::Node *t) {
223-
return t->Name() ==
224-
ir::Node::kControlDepVarName;
225-
}),
226-
n->inputs.end());
227-
n->outputs.erase(std::remove_if(n->outputs.begin(), n->outputs.end(),
228-
[n](ir::Node *t) {
229-
return t->Name() ==
230-
ir::Node::kControlDepVarName;
231-
}),
232-
n->outputs.end());
233-
}
234219
return sorted_ret;
235220
}
236221

@@ -365,12 +350,6 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
365350
}
366351
}
367352

368-
/*
369-
Dependency graph has been constructed. However, there are still data
370-
hazards need to be handled.
371-
*/
372-
PolishGraphToSupportDataHazards(&result);
373-
374353
/*
375354
* Only variables should be the leaves of graph.
376355
*/

paddle/fluid/framework/details/ssa_graph_builder.cc

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -17,48 +17,6 @@
1717
namespace paddle {
1818
namespace framework {
1919
namespace details {
20-
void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
21-
for (auto &var_map : graph->Get<GraphVars>("vars")) {
22-
for (auto &name_pair : var_map) {
23-
if (name_pair.second.size() <= 1) {
24-
continue;
25-
}
26-
auto it_new = name_pair.second.rbegin();
27-
auto it_old = name_pair.second.rbegin();
28-
++it_old;
29-
for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) {
30-
OpHandleBase *write_op = (*it_new)->GeneratedOp();
31-
const auto &read_ops = (*it_old)->PendingOps();
32-
33-
for (auto *read_op : read_ops) {
34-
// Manually add a dependency var from read_op to write_op;
35-
if (read_op == write_op) {
36-
// Read Write is the same op.
37-
continue;
38-
}
39-
40-
bool has_dep = false;
41-
for (auto read_out : read_op->Outputs()) {
42-
for (auto write_in : write_op->Inputs()) {
43-
if (read_out == write_in) {
44-
has_dep = true;
45-
break;
46-
}
47-
}
48-
}
49-
if (has_dep) continue;
50-
51-
auto *dep_var = new DummyVarHandle(
52-
graph->CreateEmptyNode("dummy", ir::Node::Type::kVariable));
53-
read_op->AddOutput(dep_var);
54-
write_op->AddInput(dep_var);
55-
graph->Get<GraphDepVars>("dep_vars").emplace(dep_var);
56-
}
57-
}
58-
}
59-
}
60-
}
61-
6220
VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
6321
ir::Graph *graph, ir::Node *node, const platform::Place &place,
6422
size_t place_offset) {

paddle/fluid/framework/details/ssa_graph_builder.h

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -57,15 +57,6 @@ class SSAGraphBuilder : public ir::Pass {
5757
DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
5858

5959
protected:
60-
/**
61-
* We only handle write after read(WAR), since it should not have a write
62-
* after write in program. If there are write after write operators, we need
63-
* prune them.
64-
*
65-
* https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
66-
*/
67-
static void PolishGraphToSupportDataHazards(ir::Graph *graph);
68-
6960
static VarHandle *CreateOrGetLatestVarHandle(ir::Graph *graph, ir::Node *node,
7061
const platform::Place &place,
7162
size_t place_offset);

paddle/fluid/framework/ir/graph.cc

Lines changed: 7 additions & 148 deletions
Original file line numberDiff line numberDiff line change
@@ -23,39 +23,6 @@ limitations under the License. */
2323
namespace paddle {
2424
namespace framework {
2525
namespace ir {
26-
/*
27-
namespace {
28-
void SortHelper(
29-
const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list,
30-
ir::Node *node, std::unordered_set<ir::Node *> *visited,
31-
std::vector<ir::Node *> *ret) {
32-
visited->insert(node);
33-
34-
for (auto adj : adj_list.at(node)) {
35-
if (visited->find(adj) == visited->end()) {
36-
SortHelper(adj_list, adj, visited, ret);
37-
}
38-
}
39-
40-
VLOG(3) << "topology sort insert: " << node->Name()
41-
<< reinterpret_cast<void *>(node) << " input " << node->inputs.size();
42-
ret->push_back(node);
43-
}
44-
45-
std::vector<ir::Node*> TopologySortOperations(
46-
const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list) {
47-
std::unordered_set<ir::Node *> visited;
48-
std::vector<ir::Node *> ret;
49-
50-
for (auto adj : adj_list) {
51-
if (visited.find(adj.first) == visited.end()) {
52-
SortHelper(adj_list, adj.first, &visited, &ret);
53-
}
54-
}
55-
return ret;
56-
}
57-
} // namespace
58-
*/
5926

6027
Graph::Graph(const ProgramDesc &program) : program_(program) {
6128
VLOG(3) << "block in program:" << program_.Size();
@@ -93,6 +60,13 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
9360
var->inputs.push_back(node);
9461
}
9562
}
63+
/**
64+
* We only handle write after read(WAR), since it should not have a write
65+
* after write in program. If there are write after write operators, we need
66+
* prune them.
67+
*
68+
* https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
69+
*/
9670
for (auto &var : var_nodes) {
9771
auto &versions = var.second;
9872
if (versions.size() <= 1) continue;
@@ -121,121 +95,6 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
12195
}
12296
}
12397
}
124-
125-
/*
126-
bool HasCircleHelper(ir::Node* node,
127-
const std::map<ir::Node *, std::unordered_set<ir::Node *>>
128-
&adj_list,
129-
std::unordered_set<ir::Node*>* visited,
130-
std::unordered_set<ir::Node*>* in_trace) {
131-
if (visited->find(node) == visited->end()) {
132-
visited->insert(node);
133-
in_trace->insert(node);
134-
135-
for (ir::Node *in : adj_list.at(node)) {
136-
if (visited->find(in) == visited->end() &&
137-
HasCircleHelper(in, adj_list, visited, in_trace)) {
138-
return true;
139-
} else if (in_trace->find(in) != in_trace->end()) {
140-
return true;
141-
}
142-
}
143-
}
144-
in_trace->erase(node);
145-
return false;
146-
}
147-
148-
bool HasCircle(const std::map<ir::Node *, std::unordered_set<ir::Node *>>
149-
&adj_list) {
150-
std::unordered_set<ir::Node*> visited;
151-
std::unordered_set<ir::Node*> in_trace;
152-
for (auto& adj : adj_list) {
153-
if (HasCircleHelper(adj.first, adj_list, &visited, &in_trace)) {
154-
return true;
155-
}
156-
}
157-
return false;
158-
}
159-
160-
std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
161-
const std::vector<ir::Node*> &nodes) {
162-
std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list;
163-
164-
for (auto &n : nodes) {
165-
if (n->NodeType() != ir::Node::Type::kOperation) continue;
166-
if (adj_list.find(n) == adj_list.end()) {
167-
adj_list[n] = std::unordered_set<ir::Node *>();
168-
}
169-
for (auto &var : n->inputs) {
170-
for (auto &adj_n : var->inputs) {
171-
PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation);
172-
adj_list[n].insert(adj_n);
173-
LOG(ERROR) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
174-
<< " -> " << n->Name() << reinterpret_cast<void *>(n)
175-
<< " via " << var->Name() << reinterpret_cast<void *>(var);
176-
}
177-
}
178-
}
179-
return adj_list;
180-
}
181-
182-
std::vector<ir::Node *> TopologySortOperationsOperationFromInToOut(
183-
const std::vector<std::unique_ptr<ir::Node>> &nodes) {
184-
std::vector<ir::Node*> tmp;
185-
for (auto& n : nodes) {
186-
tmp.push_back(n.get());
187-
}
188-
std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list =
189-
BuildOperationAdjList(tmp);
190-
191-
PADDLE_ENFORCE(!HasCircle(adj_list));
192-
std::vector<ir::Node*> ret = TopologySortOperations(adj_list);
193-
194-
ir::Node *last_backward = nullptr;
195-
std::vector<ir::Node *> optimize_ops;
196-
for (ir::Node* n : ret) {
197-
if (boost::get<int>(
198-
n->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
199-
static_cast<int>(OpRole::kBackward)) {
200-
last_backward = n;
201-
} else if (boost::get<int>(
202-
n->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
203-
static_cast<int>(OpRole::kOptimize)) {
204-
optimize_ops.push_back(n);
205-
}
206-
}
207-
208-
if (last_backward) {
209-
for (ir::Node *opt_node : optimize_ops) {
210-
ir::Node *dep_var = CreateEmptyNode(ir::Node::kControlDepVarName,
211-
ir::Node::Type::kVariable);
212-
last_backward->outputs.push_back(dep_var);
213-
dep_var->inputs.push_back(last_backward);
214-
opt_node->inputs.push_back(dep_var);
215-
dep_var->outputs.push_back(opt_node);
216-
VLOG(3) << "appending connect: " << last_backward->Name()
217-
<< reinterpret_cast<void *>(last_backward) << "->"
218-
<< opt_node->Name() << reinterpret_cast<void *>(opt_node);
219-
}
220-
}
221-
222-
PADDLE_ENFORCE(!HasCircle(adj_list));
223-
for (ir::Node *n : ret) {
224-
std::unordered_set<ir::Node *> dummy;
225-
n->inputs.erase(
226-
std::remove_if(n->inputs.begin(), n->inputs.end(),
227-
[n](ir::Node *t) {
228-
return t->Name() == ir::Node::kControlDepVarName; }),
229-
n->inputs.end());
230-
n->outputs.erase(
231-
std::remove_if(n->outputs.begin(), n->outputs.end(),
232-
[n](ir::Node *t) {
233-
return t->Name() == ir::Node::kControlDepVarName; }),
234-
n->outputs.end());
235-
}
236-
return ret;
237-
}*/
238-
23998
} // namespace ir
24099
} // namespace framework
241100
} // namespace paddle

0 commit comments

Comments
 (0)