@@ -12,55 +12,164 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
See the License for the specific language governing permissions and
13
13
limitations under the License. */
14
14
15
+ #include < algorithm>
16
+ #include < unordered_set>
17
+
15
18
#include " paddle/fluid/framework/ir/graph.h"
19
+ #include " paddle/fluid/framework/op_proto_maker.h"
16
20
#include " paddle/fluid/framework/program_desc.h"
17
21
#include " paddle/fluid/framework/var_desc.h"
18
22
19
23
namespace paddle {
20
24
namespace framework {
25
+ namespace {
26
+ void SortHelper (
27
+ const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list,
28
+ ir::Node *node, std::unordered_set<ir::Node *> *visited,
29
+ std::vector<ir::Node *> *ret) {
30
+ visited->insert (node);
31
+
32
+ for (auto adj : adj_list.at (node)) {
33
+ if (visited->find (adj) == visited->end ()) {
34
+ SortHelper (adj_list, adj, visited, ret);
35
+ }
36
+ }
37
+
38
+ VLOG (3 ) << " topology sort insert: " << node->Name ()
39
+ << reinterpret_cast <void *>(node) << " input " << node->inputs .size ();
40
+ ret->push_back (node);
41
+ }
42
+ } // namespace
21
43
22
- // NOTE(paddle-dev): This graph contains circle.
23
44
// Builds the IR graph from block 0 of `program`.
//
// Steps, in order:
//   1. Index every VarDesc of block 0 by name.
//   2. For each op (in program order) create an op node and connect it to its
//      input/output variable nodes. Every write to a variable creates a NEW
//      variable node ("version"); reads bind to the latest version seen so
//      far. var_nodes[name] therefore lists the versions oldest-first.
//   3. For each pair of consecutive versions of a variable, add a dependency
//      from every op that reads the older version to the op that writes the
//      newer one, so the write is ordered after those reads.
//   4. If at least one op has role kBackward, make every op with role
//      kOptimize depend on the last such backward op.
//
// Dependencies in steps 3 and 4 are expressed through placeholder variable
// nodes named " dummy".
//
// Throws whatever boost::get<int> throws if the op-role attribute is not an
// int, and std::out_of_range (from all_vars.at) if an op output name has no
// VarDesc in block 0.
Graph::Graph(const ProgramDesc &program) : program_(program) {
  VLOG(3) << " block in program:" << program_.Size();
  std::unordered_map<std::string, VarDesc *> all_vars;
  for (auto *var : program.Block(0).AllVars()) {
    all_vars.emplace(var->Name(), var);
  }

  // Wires `from` -> placeholder var -> `to`. The push order matches the
  // original hand-written sequence so the resulting vectors are identical.
  auto add_dependency = [this](ir::Node *from, ir::Node *to) {
    ir::Node *dep_var = CreateEmptyNode(" dummy", ir::Node::Type::kVariable);
    from->outputs.push_back(dep_var);
    dep_var->inputs.push_back(from);
    to->inputs.push_back(dep_var);
    dep_var->outputs.push_back(to);
  };

  ir::Node *last_backward = nullptr;
  std::vector<ir::Node *> optimize_ops;
  std::map<std::string, std::vector<ir::Node *>> var_nodes;
  for (auto *op : program.Block(0).AllOps()) {
    ir::Node *node = CreateOpNode(op);

    // Read the role attribute once instead of once per comparison.
    const int op_role =
        boost::get<int>(op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName()));
    if (op_role == static_cast<int>(OpRole::kBackward)) {
      last_backward = node;
    } else if (op_role == static_cast<int>(OpRole::kOptimize)) {
      optimize_ops.push_back(node);
    }

    for (auto &each_var_name : op->InputArgumentNames()) {
      ir::Node *var = nullptr;
      if (var_nodes.find(each_var_name) != var_nodes.end()) {
        // Read the most recently created version of this variable.
        var = var_nodes.at(each_var_name).back();
      } else if (all_vars.count(each_var_name) != 0) {
        var = CreateVarNode(all_vars.at(each_var_name));
        var_nodes[each_var_name].push_back(var);
      } else {
        // TODO(paddle-dev): Seems some assumption doesn't hold?
        VLOG(3) << op->Type()
                << " input var not in all_var list: " << each_var_name;
        var = CreateEmptyNode(each_var_name, ir::Node::Type::kVariable);
        var_nodes[each_var_name].push_back(var);
      }
      node->inputs.push_back(var);
      var->outputs.push_back(node);
    }

    for (auto &each_var_name : op->OutputArgumentNames()) {
      // Each write produces a fresh version of the variable.
      ir::Node *var = CreateVarNode(all_vars.at(each_var_name));
      var_nodes[each_var_name].push_back(var);
      node->outputs.push_back(var);
      var->inputs.push_back(node);
    }
  }

  for (auto &name_and_versions : var_nodes) {
    auto &versions = name_and_versions.second;
    if (versions.size() <= 1) continue;

    // Walk consecutive (older, newer) pairs from the newest version backwards.
    for (size_t i = versions.size() - 1; i > 0; --i) {
      ir::Node *newer = versions[i];
      ir::Node *older = versions[i - 1];

      ir::Node *write_op = newer->inputs.empty() ? nullptr : newer->inputs[0];
      // A non-first version is expected to have a producing op; fail loudly
      // instead of dereferencing a null pointer below.
      PADDLE_ENFORCE(write_op != nullptr);

      const auto &read_ops = older->outputs;
      for (auto *read_op : read_ops) {
        // Manually add a dependency var from read_op to write_op.
        if (read_op == write_op) {
          // Read and write are the same op; no extra edge is needed.
          continue;
        }
        add_dependency(read_op, write_op);
      }
    }
  }

  if (last_backward) {
    for (ir::Node *opt_node : optimize_ops) {
      add_dependency(last_backward, opt_node);
      VLOG(3) << " appending connect: " << last_backward->Name()
              << reinterpret_cast<void *>(last_backward) << " ->"
              << opt_node->Name() << reinterpret_cast<void *>(opt_node);
    }
  }
}
131
+
132
+ std::vector<ir::Node *> TopologySortOperationFromInToOut (
133
+ const std::vector<std::unique_ptr<ir::Node>> &nodes) {
134
+ std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list;
135
+ std::unordered_set<ir::Node *> visited;
136
+ std::vector<ir::Node *> ret;
137
+
138
+ for (auto &n : nodes) {
139
+ if (n->NodeType () != ir::Node::Type::kOperation ) continue ;
140
+ if (adj_list.find (n.get ()) == adj_list.end ()) {
141
+ adj_list[n.get ()] = std::unordered_set<ir::Node *>();
142
+ }
143
+ for (auto &var : n->inputs ) {
144
+ for (auto &adj_n : var->inputs ) {
145
+ PADDLE_ENFORCE (adj_n->NodeType () == ir::Node::Type::kOperation );
146
+ adj_list[n.get ()].insert (adj_n);
147
+ LOG (ERROR) << " adj " << adj_n->Name () << reinterpret_cast <void *>(adj_n)
148
+ << " -> " << n->Name () << reinterpret_cast <void *>(n.get ())
149
+ << " via " << var->Name () << reinterpret_cast <void *>(var);
150
+ }
151
+ }
152
+ }
153
+
154
+ for (auto adj : adj_list) {
155
+ if (visited.find (adj.first ) == visited.end ()) {
156
+ SortHelper (adj_list, adj.first , &visited, &ret);
157
+ }
158
+ }
159
+
160
+ for (ir::Node *n : ret) {
161
+ std::unordered_set<ir::Node *> dummy;
162
+ n->inputs .erase (
163
+ std::remove_if (n->inputs .begin (), n->inputs .end (),
164
+ [n](ir::Node *t) { return t->Name () == " dummy" ; }),
165
+ n->inputs .end ());
166
+ n->outputs .erase (
167
+ std::remove_if (n->outputs .begin (), n->outputs .end (),
168
+ [n](ir::Node *t) { return t->Name () == " dummy" ; }),
169
+ n->outputs .end ());
170
+ }
171
+ return ret;
64
172
}
173
+
65
174
} // namespace framework
66
175
} // namespace paddle
0 commit comments