@@ -19,31 +19,43 @@ limitations under the License. */
19
19
namespace paddle {
20
20
namespace framework {
21
21
22
+ // NOTE(paddle-dev): This graph contains circle.
22
23
Graph::Graph (const ProgramDesc &program) : program_(program) {
23
24
std::unordered_map<std::string, VarDesc *> all_vars;
24
25
for (auto *var : program.Block (0 ).AllVars ()) {
25
26
all_vars.emplace (var->Name (), var);
26
27
}
27
28
29
+ std::map<std::string, ir::Node *> var_nodes;
28
30
for (auto *op : program.Block (0 ).AllOps ()) {
29
31
ir::Node *node = CreateOpNode (op);
30
32
31
33
for (auto &each_var_name : op->InputArgumentNames ()) {
32
34
ir::Node *var = nullptr ;
33
- if (all_vars.count (each_var_name) != 0 ) {
35
+ if (var_nodes.find (each_var_name) != var_nodes.end ()) {
36
+ var = var_nodes.at (each_var_name);
37
+ } else if (all_vars.count (each_var_name) != 0 ) {
34
38
var = CreateVarNode (all_vars.at (each_var_name));
39
+ var_nodes[each_var_name] = var;
35
40
} else {
36
41
// TODO(paddle-dev): Seems some assumption doesn't hold?
37
42
LOG (ERROR) << op->Type ()
38
43
<< " input var not in all_var list: " << each_var_name;
39
44
var = CreateEmptyNode (each_var_name);
45
+ var_nodes[each_var_name] = var;
40
46
}
41
47
node->inputs .push_back (var);
42
48
var->outputs .push_back (node);
43
49
}
44
50
45
51
for (auto &each_var_name : op->OutputArgumentNames ()) {
46
- ir::Node *var = CreateVarNode (all_vars.at (each_var_name));
52
+ ir::Node *var = nullptr ;
53
+ if (var_nodes.find (each_var_name) != var_nodes.end ()) {
54
+ var = var_nodes.at (each_var_name);
55
+ } else {
56
+ var = CreateVarNode (all_vars.at (each_var_name));
57
+ var_nodes[each_var_name] = var;
58
+ }
47
59
node->outputs .push_back (var);
48
60
var->inputs .push_back (node);
49
61
}
0 commit comments