Skip to content

Commit 9c9e28b

Browse files
committed
fix program to graph
1 parent 64eaa4c commit 9c9e28b

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,10 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
210210
size_t cur_device_id = 0;
211211
bool is_forwarding = true;
212212

213-
// TODO(panyx0718): FIXME: nodes should be sorted by "program" order.
213+
// NOTE: Currently, passes before SSAGraphBuilder cannot reorder
214+
// forward, backward nodes. E.g. you can't append an forward node
215+
// at the end of the node list.
216+
// TODO(panyx0718): FIXME: Needs to sort by forward->backward order.
214217
for (auto &node : nodes) {
215218
if (node->NodeType() != ir::Node::Type::kOperation) continue;
216219
if (boost::get<int>(

paddle/fluid/framework/ir/graph.cc

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,31 +19,43 @@ limitations under the License. */
1919
namespace paddle {
2020
namespace framework {
2121

22+
// NOTE(paddle-dev): This graph contains circle.
2223
Graph::Graph(const ProgramDesc &program) : program_(program) {
2324
std::unordered_map<std::string, VarDesc *> all_vars;
2425
for (auto *var : program.Block(0).AllVars()) {
2526
all_vars.emplace(var->Name(), var);
2627
}
2728

29+
std::map<std::string, ir::Node *> var_nodes;
2830
for (auto *op : program.Block(0).AllOps()) {
2931
ir::Node *node = CreateOpNode(op);
3032

3133
for (auto &each_var_name : op->InputArgumentNames()) {
3234
ir::Node *var = nullptr;
33-
if (all_vars.count(each_var_name) != 0) {
35+
if (var_nodes.find(each_var_name) != var_nodes.end()) {
36+
var = var_nodes.at(each_var_name);
37+
} else if (all_vars.count(each_var_name) != 0) {
3438
var = CreateVarNode(all_vars.at(each_var_name));
39+
var_nodes[each_var_name] = var;
3540
} else {
3641
// TODO(paddle-dev): Seems some assumption doesn't hold?
3742
LOG(ERROR) << op->Type()
3843
<< " input var not in all_var list: " << each_var_name;
3944
var = CreateEmptyNode(each_var_name);
45+
var_nodes[each_var_name] = var;
4046
}
4147
node->inputs.push_back(var);
4248
var->outputs.push_back(node);
4349
}
4450

4551
for (auto &each_var_name : op->OutputArgumentNames()) {
46-
ir::Node *var = CreateVarNode(all_vars.at(each_var_name));
52+
ir::Node *var = nullptr;
53+
if (var_nodes.find(each_var_name) != var_nodes.end()) {
54+
var = var_nodes.at(each_var_name);
55+
} else {
56+
var = CreateVarNode(all_vars.at(each_var_name));
57+
var_nodes[each_var_name] = var;
58+
}
4759
node->outputs.push_back(var);
4860
var->inputs.push_back(node);
4961
}

0 commit comments

Comments
 (0)