@@ -23,39 +23,6 @@ limitations under the License. */
23
23
namespace paddle {
24
24
namespace framework {
25
25
namespace ir {
26
- /*
27
- namespace {
28
- void SortHelper(
29
- const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list,
30
- ir::Node *node, std::unordered_set<ir::Node *> *visited,
31
- std::vector<ir::Node *> *ret) {
32
- visited->insert(node);
33
-
34
- for (auto adj : adj_list.at(node)) {
35
- if (visited->find(adj) == visited->end()) {
36
- SortHelper(adj_list, adj, visited, ret);
37
- }
38
- }
39
-
40
- VLOG(3) << "topology sort insert: " << node->Name()
41
- << reinterpret_cast<void *>(node) << " input " << node->inputs.size();
42
- ret->push_back(node);
43
- }
44
-
45
- std::vector<ir::Node*> TopologySortOperations(
46
- const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list) {
47
- std::unordered_set<ir::Node *> visited;
48
- std::vector<ir::Node *> ret;
49
-
50
- for (auto adj : adj_list) {
51
- if (visited.find(adj.first) == visited.end()) {
52
- SortHelper(adj_list, adj.first, &visited, &ret);
53
- }
54
- }
55
- return ret;
56
- }
57
- } // namespace
58
- */
59
26
60
27
Graph::Graph (const ProgramDesc &program) : program_(program) {
61
28
VLOG (3 ) << " block in program:" << program_.Size ();
@@ -93,6 +60,13 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
93
60
var->inputs .push_back (node);
94
61
}
95
62
}
63
+ /* *
64
+ * We only handle write after read(WAR), since it should not have a write
65
+ * after write in program. If there are write after write operators, we need
66
+ * prune them.
67
+ *
68
+ * https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
69
+ */
96
70
for (auto &var : var_nodes) {
97
71
auto &versions = var.second ;
98
72
if (versions.size () <= 1 ) continue ;
@@ -121,121 +95,6 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
121
95
}
122
96
}
123
97
}
124
-
125
- /*
126
- bool HasCircleHelper(ir::Node* node,
127
- const std::map<ir::Node *, std::unordered_set<ir::Node *>>
128
- &adj_list,
129
- std::unordered_set<ir::Node*>* visited,
130
- std::unordered_set<ir::Node*>* in_trace) {
131
- if (visited->find(node) == visited->end()) {
132
- visited->insert(node);
133
- in_trace->insert(node);
134
-
135
- for (ir::Node *in : adj_list.at(node)) {
136
- if (visited->find(in) == visited->end() &&
137
- HasCircleHelper(in, adj_list, visited, in_trace)) {
138
- return true;
139
- } else if (in_trace->find(in) != in_trace->end()) {
140
- return true;
141
- }
142
- }
143
- }
144
- in_trace->erase(node);
145
- return false;
146
- }
147
-
148
- bool HasCircle(const std::map<ir::Node *, std::unordered_set<ir::Node *>>
149
- &adj_list) {
150
- std::unordered_set<ir::Node*> visited;
151
- std::unordered_set<ir::Node*> in_trace;
152
- for (auto& adj : adj_list) {
153
- if (HasCircleHelper(adj.first, adj_list, &visited, &in_trace)) {
154
- return true;
155
- }
156
- }
157
- return false;
158
- }
159
-
160
- std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
161
- const std::vector<ir::Node*> &nodes) {
162
- std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list;
163
-
164
- for (auto &n : nodes) {
165
- if (n->NodeType() != ir::Node::Type::kOperation) continue;
166
- if (adj_list.find(n) == adj_list.end()) {
167
- adj_list[n] = std::unordered_set<ir::Node *>();
168
- }
169
- for (auto &var : n->inputs) {
170
- for (auto &adj_n : var->inputs) {
171
- PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation);
172
- adj_list[n].insert(adj_n);
173
- LOG(ERROR) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
174
- << " -> " << n->Name() << reinterpret_cast<void *>(n)
175
- << " via " << var->Name() << reinterpret_cast<void *>(var);
176
- }
177
- }
178
- }
179
- return adj_list;
180
- }
181
-
182
- std::vector<ir::Node *> TopologySortOperationsOperationFromInToOut(
183
- const std::vector<std::unique_ptr<ir::Node>> &nodes) {
184
- std::vector<ir::Node*> tmp;
185
- for (auto& n : nodes) {
186
- tmp.push_back(n.get());
187
- }
188
- std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list =
189
- BuildOperationAdjList(tmp);
190
-
191
- PADDLE_ENFORCE(!HasCircle(adj_list));
192
- std::vector<ir::Node*> ret = TopologySortOperations(adj_list);
193
-
194
- ir::Node *last_backward = nullptr;
195
- std::vector<ir::Node *> optimize_ops;
196
- for (ir::Node* n : ret) {
197
- if (boost::get<int>(
198
- n->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
199
- static_cast<int>(OpRole::kBackward)) {
200
- last_backward = n;
201
- } else if (boost::get<int>(
202
- n->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
203
- static_cast<int>(OpRole::kOptimize)) {
204
- optimize_ops.push_back(n);
205
- }
206
- }
207
-
208
- if (last_backward) {
209
- for (ir::Node *opt_node : optimize_ops) {
210
- ir::Node *dep_var = CreateEmptyNode(ir::Node::kControlDepVarName,
211
- ir::Node::Type::kVariable);
212
- last_backward->outputs.push_back(dep_var);
213
- dep_var->inputs.push_back(last_backward);
214
- opt_node->inputs.push_back(dep_var);
215
- dep_var->outputs.push_back(opt_node);
216
- VLOG(3) << "appending connect: " << last_backward->Name()
217
- << reinterpret_cast<void *>(last_backward) << "->"
218
- << opt_node->Name() << reinterpret_cast<void *>(opt_node);
219
- }
220
- }
221
-
222
- PADDLE_ENFORCE(!HasCircle(adj_list));
223
- for (ir::Node *n : ret) {
224
- std::unordered_set<ir::Node *> dummy;
225
- n->inputs.erase(
226
- std::remove_if(n->inputs.begin(), n->inputs.end(),
227
- [n](ir::Node *t) {
228
- return t->Name() == ir::Node::kControlDepVarName; }),
229
- n->inputs.end());
230
- n->outputs.erase(
231
- std::remove_if(n->outputs.begin(), n->outputs.end(),
232
- [n](ir::Node *t) {
233
- return t->Name() == ir::Node::kControlDepVarName; }),
234
- n->outputs.end());
235
- }
236
- return ret;
237
- }*/
238
-
239
98
} // namespace ir
240
99
} // namespace framework
241
100
} // namespace paddle
0 commit comments