Skip to content

Commit 7ae73e3

Browse files
authored
Merge pull request #12432 from Superjomn/fea/analysis-ssa
inference analysis supports SSA
2 parents 271b724 + 15c2f1a commit 7ae73e3

File tree

3 files changed

+18
-5
lines changed

3 files changed

+18
-5
lines changed

paddle/fluid/inference/analysis/data_flow_graph.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ namespace analysis {
3636

3737
/*
3838
* DataFlowGraph - A container of Value and Function Nodes.
39+
*
40+
* This is the base graph for any other type of graphs, such as SSA or CFG.
3941
*/
4042
struct DataFlowGraph {
4143
NodeMap nodes;

paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) {
4040
PADDLE_ENFORCE(graph);
4141
PADDLE_ENFORCE(desc_);
4242
// insert vars
43+
// The `var2id` keeps a map from a variable's name to its Node-id, the Node-id
44+
// will keep updating to its latest alias during the graph-building.
4345
std::unordered_map<std::string, size_t> var2id;
4446
auto &main_block = desc_->blocks(framework::kRootBlockIndex);
4547
for (int i = 0; i < main_block.vars_size(); i++) {
@@ -51,6 +53,15 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) {
5153
var2id[var.name()] = v->id();
5254
}
5355

56+
// The variables in a SSA can only write once, so if a variable is written
57+
// multiple times(quite common in our ProgramDesc design), multiple alias
58+
// Nodes of this variable will be created, and each will just write once.
59+
60+
// An set that keep all the names of the variables(the original, not alias)
61+
// that have been written(as outputs). Once an Op's output variable hit the
62+
// set, it should create a new alias and update the global alias for this
63+
// variable. And that make a Data Flow Graph a SSA.
64+
std::unordered_set<Node *> unique_written_vars;
5465
for (int i = 0; i < main_block.ops_size(); i++) {
5566
const auto &op = main_block.ops(i);
5667
auto *o = graph->nodes.Create(Node::Type::kFunction);
@@ -62,33 +73,33 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) {
6273
o->SetPbMsg(op.SerializeAsString());
6374

6475
// set inputs and outputs
65-
std::unordered_set<Node *> inlinks;
6676
for (int j = 0; j < op.inputs_size(); j++) {
6777
auto &in_var = op.inputs(j);
6878
for (int k = 0; k < in_var.arguments_size(); k++) {
6979
auto *in = graph->nodes.GetMutable(var2id.at(in_var.arguments(k)));
7080
in->outlinks.push_back(o);
7181
o->inlinks.push_back(in);
72-
inlinks.insert(in);
7382
}
7483
}
7584
for (int j = 0; j < op.outputs_size(); j++) {
7685
auto &out_var = op.outputs(j);
7786
for (int k = 0; k < out_var.arguments_size(); k++) {
7887
auto *out = graph->nodes.GetMutable(var2id[out_var.arguments(k)]);
79-
if (inlinks.count(out)) {
88+
if (unique_written_vars.count(out)) {
8089
// Loop found, for example, a = op(a), use SSA, change to a1 = op(a).
8190
auto *out_alias = graph->nodes.Create(Node::Type::kValue);
8291
out_alias->SetName(out->name());
8392
out_alias->SetPbDesc(out->pb_desc());
8493
out_alias->SetPbMsg(out->pb_msg());
85-
var2id[out_alias->name()] = out_alias->id(); // update a -> a0
94+
var2id[out_alias->name()] =
95+
out_alias->id(); // update variable's alias Node
8696
LOG(INFO) << "loop found in graph, create SSA alias node ["
8797
<< out_alias->repr() << "] for [" << out->repr() << "]";
8898
out = out_alias;
8999
}
90100
out->inlinks.push_back(o);
91101
o->outlinks.push_back(out);
102+
unique_written_vars.insert(out);
92103
}
93104
}
94105
}

paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ namespace inference {
3030
namespace analysis {
3131

3232
/*
33-
* Transform a FluidDesc to a data flow graph.
33+
* Transform a FluidDesc to a SSA.
3434
*/
3535
class FluidToDataFlowGraphPass final : public DataFlowGraphPass {
3636
public:

0 commit comments

Comments
 (0)