Skip to content

Commit 25706d0

Browse files
committed
Properly set up the dependency between concat and fetch_bar
1 parent 398cfb4 commit 25706d0

File tree

1 file changed

+77
-0
lines changed

1 file changed

+77
-0
lines changed

paddle/fluid/framework/ir/graph.cc

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,68 @@ namespace paddle {
2424
namespace framework {
2525
namespace ir {
2626

27+
std::vector<std::string> FindDistTrainSendVars(
28+
const std::vector<ir::Node *> &nodes) {
29+
std::vector<std::string> send_vars;
30+
// since parameters are all in block 0,
31+
// it's enough to only scan send ops in block 0
32+
for (auto &node : nodes) {
33+
auto op_vars = node->Op()->InputArgumentNames();
34+
send_vars.reserve(send_vars.size() +
35+
std::distance(op_vars.begin(), op_vars.end()));
36+
send_vars.insert(send_vars.end(), op_vars.begin(), op_vars.end());
37+
}
38+
return send_vars;
39+
}
40+
41+
std::vector<std::string> FindDistTrainRecvVars(
42+
const std::vector<ir::Node *> &nodes) {
43+
std::vector<std::string> recv_vars;
44+
for (auto &node : nodes) {
45+
auto op_vars = node->Op()->OutputArgumentNames();
46+
recv_vars.reserve(recv_vars.size() +
47+
std::distance(op_vars.begin(), op_vars.end()));
48+
recv_vars.insert(recv_vars.end(), op_vars.begin(), op_vars.end());
49+
}
50+
return recv_vars;
51+
}
52+
53+
// Returns true when `node` looks like an operator inserted for
// distributed training: one of its outputs is a send variable or one of
// its inputs is a recv variable, and that variable name carries the
// `.block` suffix that marks a variable split by the
// DistributeTranspiler
// [python/paddle/fluid/transpiler/distribute_transpiler.py].
bool IsDistTrainOp(ir::Node *node, const std::vector<std::string> &send_vars,
                   const std::vector<std::string> &recv_vars) {
  // Without both send and recv variables this cannot be a dist-train op.
  if (send_vars.empty() || recv_vars.empty()) {
    return false;
  }

  /**
   * Check any of opvars contains `.block` and in sendvars
   */
  auto checker = [](const std::vector<std::string> &opvars,
                    const std::vector<std::string> &rpc_vars) -> bool {
    for (auto &var : opvars) {
      // a variable name with the suffix `.block` means it's a splited
      // variable by (DistributeTranspiler)
      // [python/paddle/fluid/transpiler/distribute_transpiler.py]
      if (var.find(".block") != std::string::npos &&
          std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
        return true;
      }
    }
    return false;
  };

  // Gather the variable names surrounding `node` once, so the checker
  // works on plain string lists.
  std::vector<std::string> input_var_names;
  std::vector<std::string> output_var_names;
  input_var_names.reserve(node->inputs.size());
  output_var_names.reserve(node->outputs.size());
  for (ir::Node *input : node->inputs) {
    input_var_names.push_back(input->Name());
  }
  for (ir::Node *output : node->outputs) {
    output_var_names.push_back(output->Name());
  }

  return checker(output_var_names, send_vars) ||
         checker(input_var_names, recv_vars);
}
2789
Graph::Graph(const ProgramDesc &program) : program_(program) {
2890
VLOG(3) << "block in program:" << program_.Size();
2991
std::unordered_map<std::string, VarDesc *> all_vars;
@@ -104,6 +166,21 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
104166
dep_var->outputs.push_back(fetch_bar);
105167
}
106168
}
169+
170+
std::vector<std::string> send_vars = FindDistTrainSendVars(send_ops);
171+
std::vector<std::string> recv_vars = FindDistTrainRecvVars(recv_ops);
172+
for (ir::Node *node : Nodes()) {
173+
if (IsDistTrainOp(node, send_vars, recv_vars)) {
174+
if (fetch_bar && node->Name() == "concat") {
175+
ir::Node *dep_var = CreateControlDepVar();
176+
fetch_bar->outputs.push_back(dep_var);
177+
dep_var->inputs.push_back(fetch_bar);
178+
node->inputs.push_back(dep_var);
179+
dep_var->outputs.push_back(node);
180+
}
181+
}
182+
}
183+
107184
/**
108185
* We only handle write after read(WAR), since it should not have a write
109186
* after write in program. If there are write after write operators, we need

0 commit comments

Comments
 (0)