@@ -22,6 +22,7 @@ limitations under the License. */
22
22
23
23
namespace paddle {
24
24
namespace framework {
25
+ /*
25
26
namespace {
26
27
void SortHelper(
27
28
const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list,
@@ -39,7 +40,21 @@ void SortHelper(
39
40
<< reinterpret_cast<void *>(node) << " input " << node->inputs.size();
40
41
ret->push_back(node);
41
42
}
43
+
44
+ std::vector<ir::Node*> TopologySort(
45
+ const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list) {
46
+ std::unordered_set<ir::Node *> visited;
47
+ std::vector<ir::Node *> ret;
48
+
49
+ for (auto adj : adj_list) {
50
+ if (visited.find(adj.first) == visited.end()) {
51
+ SortHelper(adj_list, adj.first, &visited, &ret);
52
+ }
53
+ }
54
+ return ret;
55
+ }
42
56
} // namespace
57
+ */
43
58
44
59
Graph::Graph (const ProgramDesc &program) : program_(program) {
45
60
VLOG (3 ) << " block in program:" << program_.Size ();
@@ -48,20 +63,9 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
48
63
all_vars.emplace (var->Name (), var);
49
64
}
50
65
51
- ir::Node *last_backward = nullptr ;
52
- std::vector<ir::Node *> optimize_ops;
53
66
std::map<std::string, std::vector<ir::Node *>> var_nodes;
54
67
for (auto *op : program.Block (0 ).AllOps ()) {
55
68
ir::Node *node = CreateOpNode (op);
56
- if (boost::get<int >(
57
- op->GetAttr (OpProtoAndCheckerMaker::OpRoleAttrName ())) ==
58
- static_cast <int >(OpRole::kBackward )) {
59
- last_backward = node;
60
- } else if (boost::get<int >(
61
- op->GetAttr (OpProtoAndCheckerMaker::OpRoleAttrName ())) ==
62
- static_cast <int >(OpRole::kOptimize )) {
63
- optimize_ops.push_back (node);
64
- }
65
69
66
70
for (auto &each_var_name : op->InputArgumentNames ()) {
67
71
ir::Node *var = nullptr ;
@@ -106,70 +110,130 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
106
110
// Read Write is the same op.
107
111
continue ;
108
112
}
109
- ir::Node *dep_var = CreateEmptyNode (" dummy" , ir::Node::Type::kVariable );
113
+ ir::Node *dep_var = CreateEmptyNode (ir::Node::kControlDepVarName ,
114
+ ir::Node::Type::kVariable );
110
115
read_op->outputs .push_back (dep_var);
111
116
dep_var->inputs .push_back (read_op);
112
117
write_op->inputs .push_back (dep_var);
113
118
dep_var->outputs .push_back (write_op);
114
119
}
115
120
}
116
121
}
122
+ }
117
123
118
- if (last_backward) {
119
- for (ir::Node *opt_node : optimize_ops) {
120
- ir::Node *dep_var = CreateEmptyNode (" dummy" , ir::Node::Type::kVariable );
121
- last_backward->outputs .push_back (dep_var);
122
- dep_var->inputs .push_back (last_backward);
123
- opt_node->inputs .push_back (dep_var);
124
- dep_var->outputs .push_back (opt_node);
125
- VLOG (3 ) << " appending connect: " << last_backward->Name ()
126
- << reinterpret_cast <void *>(last_backward) << " ->"
127
- << opt_node->Name () << reinterpret_cast <void *>(opt_node);
124
+ /*
125
+ bool HasCircleHelper(ir::Node* node,
126
+ const std::map<ir::Node *, std::unordered_set<ir::Node *>>
127
+ &adj_list,
128
+ std::unordered_set<ir::Node*>* visited,
129
+ std::unordered_set<ir::Node*>* in_trace) {
130
+ if (visited->find(node) == visited->end()) {
131
+ visited->insert(node);
132
+ in_trace->insert(node);
133
+
134
+ for (ir::Node *in : adj_list.at(node)) {
135
+ if (visited->find(in) == visited->end() &&
136
+ HasCircleHelper(in, adj_list, visited, in_trace)) {
137
+ return true;
138
+ } else if (in_trace->find(in) != in_trace->end()) {
139
+ return true;
140
+ }
128
141
}
129
142
}
143
+ in_trace->erase(node);
144
+ return false;
130
145
}
131
146
132
- std::vector<ir::Node *> TopologySortOperationFromInToOut (
133
- const std::vector<std::unique_ptr<ir::Node>> &nodes) {
147
+ bool HasCircle(const std::map<ir::Node *, std::unordered_set<ir::Node *>>
148
+ &adj_list) {
149
+ std::unordered_set<ir::Node*> visited;
150
+ std::unordered_set<ir::Node*> in_trace;
151
+ for (auto& adj : adj_list) {
152
+ if (HasCircleHelper(adj.first, adj_list, &visited, &in_trace)) {
153
+ return true;
154
+ }
155
+ }
156
+ return false;
157
+ }
158
+
159
+ std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildAdjList(
160
+ const std::vector<ir::Node*> &nodes) {
134
161
std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list;
135
- std::unordered_set<ir::Node *> visited;
136
- std::vector<ir::Node *> ret;
137
162
138
163
for (auto &n : nodes) {
139
164
if (n->NodeType() != ir::Node::Type::kOperation) continue;
140
- if (adj_list.find (n. get () ) == adj_list.end ()) {
141
- adj_list[n. get () ] = std::unordered_set<ir::Node *>();
165
+ if (adj_list.find(n) == adj_list.end()) {
166
+ adj_list[n] = std::unordered_set<ir::Node *>();
142
167
}
143
168
for (auto &var : n->inputs) {
144
169
for (auto &adj_n : var->inputs) {
145
170
PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation);
146
- adj_list[n. get () ].insert (adj_n);
171
+ adj_list[n].insert(adj_n);
147
172
LOG(ERROR) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
148
- << " -> " << n->Name () << reinterpret_cast <void *>(n. get () )
173
+ << " -> " << n->Name() << reinterpret_cast<void *>(n)
149
174
<< " via " << var->Name() << reinterpret_cast<void *>(var);
150
175
}
151
176
}
152
177
}
178
+ return adj_list;
179
+ }
153
180
154
- for (auto adj : adj_list) {
155
- if (visited.find (adj.first ) == visited.end ()) {
156
- SortHelper (adj_list, adj.first , &visited, &ret);
181
+ std::vector<ir::Node *> TopologySortOperationFromInToOut(
182
+ const std::vector<std::unique_ptr<ir::Node>> &nodes) {
183
+ std::vector<ir::Node*> tmp;
184
+ for (auto& n : nodes) {
185
+ tmp.push_back(n.get());
186
+ }
187
+ std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list =
188
+ BuildAdjList(tmp);
189
+
190
+ PADDLE_ENFORCE(!HasCircle(adj_list));
191
+ std::vector<ir::Node*> ret = TopologySort(adj_list);
192
+
193
+ ir::Node *last_backward = nullptr;
194
+ std::vector<ir::Node *> optimize_ops;
195
+ for (ir::Node* n : ret) {
196
+ if (boost::get<int>(
197
+ n->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
198
+ static_cast<int>(OpRole::kBackward)) {
199
+ last_backward = n;
200
+ } else if (boost::get<int>(
201
+ n->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
202
+ static_cast<int>(OpRole::kOptimize)) {
203
+ optimize_ops.push_back(n);
157
204
}
158
205
}
159
206
207
+ if (last_backward) {
208
+ for (ir::Node *opt_node : optimize_ops) {
209
+ ir::Node *dep_var = CreateEmptyNode(ir::Node::kControlDepVarName,
210
+ ir::Node::Type::kVariable);
211
+ last_backward->outputs.push_back(dep_var);
212
+ dep_var->inputs.push_back(last_backward);
213
+ opt_node->inputs.push_back(dep_var);
214
+ dep_var->outputs.push_back(opt_node);
215
+ VLOG(3) << "appending connect: " << last_backward->Name()
216
+ << reinterpret_cast<void *>(last_backward) << "->"
217
+ << opt_node->Name() << reinterpret_cast<void *>(opt_node);
218
+ }
219
+ }
220
+
221
+ PADDLE_ENFORCE(!HasCircle(adj_list));
160
222
for (ir::Node *n : ret) {
161
223
std::unordered_set<ir::Node *> dummy;
162
224
n->inputs.erase(
163
225
std::remove_if(n->inputs.begin(), n->inputs.end(),
164
- [n](ir::Node *t) { return t->Name () == " dummy" ; }),
226
+ [n](ir::Node *t) {
227
+ return t->Name() == ir::Node::kControlDepVarName; }),
165
228
n->inputs.end());
166
229
n->outputs.erase(
167
230
std::remove_if(n->outputs.begin(), n->outputs.end(),
168
- [n](ir::Node *t) { return t->Name () == " dummy" ; }),
231
+ [n](ir::Node *t) {
232
+ return t->Name() == ir::Node::kControlDepVarName; }),
169
233
n->outputs.end());
170
234
}
171
235
return ret;
172
- }
236
+ }*/
173
237
174
238
} // namespace framework
175
239
} // namespace paddle
0 commit comments