@@ -40,6 +40,8 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) {
40
40
PADDLE_ENFORCE (graph);
41
41
PADDLE_ENFORCE (desc_);
42
42
// insert vars
43
+ // The `var2id` keeps a map from a variable's name to its Node-id, the Node-id
44
+ // will keep updating to its latest alias during the graph-building.
43
45
std::unordered_map<std::string, size_t > var2id;
44
46
auto &main_block = desc_->blocks (framework::kRootBlockIndex );
45
47
for (int i = 0 ; i < main_block.vars_size (); i++) {
@@ -51,6 +53,15 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) {
51
53
var2id[var.name ()] = v->id ();
52
54
}
53
55
56
+ // The variables in a SSA can only write once, so if a variable is written
57
+ // multiple times(quite common in our ProgramDesc design), multiple alias
58
+ // Nodes of this variable will be created, and each will just write once.
59
+
60
+ // An set that keep all the names of the variables(the original, not alias)
61
+ // that have been written(as outputs). Once an Op's output variable hit the
62
+ // set, it should create a new alias and update the global alias for this
63
+ // variable. And that make a Data Flow Graph a SSA.
64
+ std::unordered_set<Node *> unique_written_vars;
54
65
for (int i = 0 ; i < main_block.ops_size (); i++) {
55
66
const auto &op = main_block.ops (i);
56
67
auto *o = graph->nodes .Create (Node::Type::kFunction );
@@ -62,33 +73,33 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) {
62
73
o->SetPbMsg (op.SerializeAsString ());
63
74
64
75
// set inputs and outputs
65
- std::unordered_set<Node *> inlinks;
66
76
for (int j = 0 ; j < op.inputs_size (); j++) {
67
77
auto &in_var = op.inputs (j);
68
78
for (int k = 0 ; k < in_var.arguments_size (); k++) {
69
79
auto *in = graph->nodes .GetMutable (var2id.at (in_var.arguments (k)));
70
80
in->outlinks .push_back (o);
71
81
o->inlinks .push_back (in);
72
- inlinks.insert (in);
73
82
}
74
83
}
75
84
for (int j = 0 ; j < op.outputs_size (); j++) {
76
85
auto &out_var = op.outputs (j);
77
86
for (int k = 0 ; k < out_var.arguments_size (); k++) {
78
87
auto *out = graph->nodes .GetMutable (var2id[out_var.arguments (k)]);
79
- if (inlinks .count (out)) {
88
+ if (unique_written_vars .count (out)) {
80
89
// Loop found, for example, a = op(a), use SSA, change to a1 = op(a).
81
90
auto *out_alias = graph->nodes .Create (Node::Type::kValue );
82
91
out_alias->SetName (out->name ());
83
92
out_alias->SetPbDesc (out->pb_desc ());
84
93
out_alias->SetPbMsg (out->pb_msg ());
85
- var2id[out_alias->name ()] = out_alias->id (); // update a -> a0
94
+ var2id[out_alias->name ()] =
95
+ out_alias->id (); // update variable's alias Node
86
96
LOG (INFO) << " loop found in graph, create SSA alias node ["
87
97
<< out_alias->repr () << " ] for [" << out->repr () << " ]" ;
88
98
out = out_alias;
89
99
}
90
100
out->inlinks .push_back (o);
91
101
o->outlinks .push_back (out);
102
+ unique_written_vars.insert (out);
92
103
}
93
104
}
94
105
}
0 commit comments