@@ -24,6 +24,68 @@ namespace paddle {
namespace framework {
namespace ir {

+std::vector<std::string> FindDistTrainSendVars(
+    const std::vector<ir::Node *> &nodes) {
+  std::vector<std::string> send_vars;
+  // since parameters are all in block 0,
+  // it's enough to only scan send ops in block 0
+  for (auto &node : nodes) {
+    auto op_vars = node->Op()->InputArgumentNames();
+    send_vars.reserve(send_vars.size() +
+                      std::distance(op_vars.begin(), op_vars.end()));
+    send_vars.insert(send_vars.end(), op_vars.begin(), op_vars.end());
+  }
+  return send_vars;
+}
+
+std::vector<std::string> FindDistTrainRecvVars(
+    const std::vector<ir::Node *> &nodes) {
+  std::vector<std::string> recv_vars;
+  for (auto &node : nodes) {
+    auto op_vars = node->Op()->OutputArgumentNames();
+    recv_vars.reserve(recv_vars.size() +
+                      std::distance(op_vars.begin(), op_vars.end()));
+    recv_vars.insert(recv_vars.end(), op_vars.begin(), op_vars.end());
+  }
+  return recv_vars;
+}
+
+bool IsDistTrainOp(ir::Node *node, const std::vector<std::string> &send_vars,
+                   const std::vector<std::string> &recv_vars) {
+  if (send_vars.size() == 0 || recv_vars.size() == 0) {
+    return false;
+  }
+
+  /**
+   * Check whether any of opvars contains `.block` and is in sendvars
+   */
+  auto checker = [](const std::vector<std::string> &opvars,
+                    const std::vector<std::string> &rpc_vars) -> bool {
+    for (auto &var : opvars) {
+      // a variable name with the suffix `.block` means it's a variable
+      // split by the DistributeTranspiler
+      // [python/paddle/fluid/transpiler/distribute_transpiler.py]
+      if (var.find(".block") != std::string::npos &&
+          std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
+        return true;
+      }
+    }
+    return false;
+  };
+
+  std::vector<std::string> input_var_names;
+  std::vector<std::string> output_var_names;
+  for (ir::Node *input : node->inputs) {
+    input_var_names.push_back(input->Name());
+  }
+  for (ir::Node *output : node->outputs) {
+    output_var_names.push_back(output->Name());
+  }
+
+  return checker(output_var_names, send_vars) ||
+         checker(input_var_names, recv_vars);
+}
+
Graph::Graph(const ProgramDesc &program) : program_(program) {
  VLOG(3) << "block in program:" << program_.Size();
  std::unordered_map<std::string, VarDesc *> all_vars;
@@ -104,6 +166,21 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
      dep_var->outputs.push_back(fetch_bar);
    }
  }
+
+  std::vector<std::string> send_vars = FindDistTrainSendVars(send_ops);
+  std::vector<std::string> recv_vars = FindDistTrainRecvVars(recv_ops);
+  for (ir::Node *node : Nodes()) {
+    if (IsDistTrainOp(node, send_vars, recv_vars)) {
+      if (fetch_bar && node->Name() == "concat") {
+        ir::Node *dep_var = CreateControlDepVar();
+        fetch_bar->outputs.push_back(dep_var);
+        dep_var->inputs.push_back(fetch_bar);
+        node->inputs.push_back(dep_var);
+        dep_var->outputs.push_back(node);
+      }
+    }
+  }
+
  /**
   * We only handle write after read(WAR), since it should not have a write
   * after write in program. If there are write after write operators, we need
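Note: the following is a minimal, self-contained sketch (not part of the commit) of the variable-name check used by IsDistTrainOp above. The helper name ContainsSplitRpcVar and the sample variable names are hypothetical; only the matching logic mirrors the checker lambda in the diff.

// sketch.cc -- illustrates how a `.block` suffix plus membership in the
// send/recv (RPC) variable list marks an op as a distributed-training op.
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

bool ContainsSplitRpcVar(const std::vector<std::string> &opvars,
                         const std::vector<std::string> &rpc_vars) {
  for (const auto &var : opvars) {
    // `.block` marks a variable that was split by the DistributeTranspiler;
    // it only counts if the same name also appears among the RPC variables.
    if (var.find(".block") != std::string::npos &&
        std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
      return true;
    }
  }
  return false;
}

int main() {
  // Hypothetical names for a parameter split into blocks.
  std::vector<std::string> send_vars = {"fc_0.w_0.block0", "fc_0.w_0.block1"};
  std::vector<std::string> split_outputs = {"fc_0.w_0.block1"};
  std::vector<std::string> plain_outputs = {"fc_0.tmp_1"};

  std::cout << ContainsSplitRpcVar(split_outputs, send_vars) << "\n";  // 1
  std::cout << ContainsSplitRpcVar(plain_outputs, send_vars) << "\n";  // 0
  return 0;
}

In the second hunk, this classification feeds the fetch_bar branch: for a dist-train op named "concat", a control-dependency variable is created and wired from fetch_bar into the node, so the concat cannot be scheduled before the fetch barrier.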