12
12
// See the License for the specific language governing permissions and
13
13
// limitations under the License.
14
14
#include " paddle/fluid/framework/details/multi_devices_graph_builder.h"
15
- #include < fstream>
16
15
#include < utility>
17
16
#include " paddle/fluid/framework/details/broadcast_op_handle.h"
18
17
#include " paddle/fluid/framework/details/computation_op_handle.h"
@@ -79,31 +78,63 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
79
78
CreateOpOutput (result, op_handle, each_var_name, p, place_id);
80
79
}
81
80
}
82
- bool MultiDevSSAGraphBuilder::IsDistTrainOp (const OpDesc &op,
83
- OpDesc *send_op) const {
84
- if (send_op == nullptr ) {
81
+
82
+ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars (
83
+ const ProgramDesc &program) const {
84
+ std::vector<std::string> send_vars;
85
+ for (auto *op : program.Block (0 ).AllOps ()) {
86
+ if (op->Type () == " send_vars" || op->Type () == " send" ) {
87
+ auto op_vars = op->InputArgumentNames ();
88
+ send_vars.reserve (send_vars.size () +
89
+ std::distance (op_vars.begin (), op_vars.end ()));
90
+ send_vars.insert (send_vars.end (), op_vars.begin (), op_vars.end ());
91
+ }
92
+ }
93
+ return send_vars;
94
+ }
95
+
96
+ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars (
97
+ const ProgramDesc &program) const {
98
+ std::vector<std::string> recv_vars;
99
+ for (auto *op : program.Block (0 ).AllOps ()) {
100
+ if (op->Type () == " recv" || op->Type () == " send" ) {
101
+ auto op_vars = op->OutputArgumentNames ();
102
+ recv_vars.reserve (recv_vars.size () +
103
+ std::distance (op_vars.begin (), op_vars.end ()));
104
+ recv_vars.insert (recv_vars.end (), op_vars.begin (), op_vars.end ());
105
+ }
106
+ }
107
+ return recv_vars;
108
+ }
109
+
110
+ bool MultiDevSSAGraphBuilder::IsDistTrainOp (
111
+ const OpDesc &op, const std::vector<std::string> &send_vars,
112
+ const std::vector<std::string> &recv_vars) const {
113
+ if (send_vars.size () == 0 || recv_vars.size () == 0 ) {
85
114
return false ;
86
115
}
87
116
88
117
/* *
89
118
* Check any of opvars contains `.block` and in sendvars
90
119
*/
91
120
auto checker = [](const std::vector<std::string> &opvars,
92
- const std::vector<std::string> &sendvars ) -> bool {
121
+ const std::vector<std::string> &rpc_vars ) -> bool {
93
122
for (auto &var : opvars) {
94
123
if (var.find (" .block" ) != std::string::npos &&
95
- std::find (sendvars .begin (), sendvars .end (), var) != sendvars .end ()) {
124
+ std::find (rpc_vars .begin (), rpc_vars .end (), var) != rpc_vars .end ()) {
96
125
return true ;
97
126
}
98
127
}
99
128
return false ;
100
129
};
101
130
102
- if (op.Type () == " split" || op.Type () == " split_byref" ) {
103
- return checker (op.OutputArgumentNames (), send_op->InputArgumentNames ());
131
+ if (op.Type () == " split" || op.Type () == " split_byref" ||
132
+ op.Type () == " split_selected_rows" ) {
133
+ return checker (op.OutputArgumentNames (), send_vars);
104
134
} else if (op.Type () == " concat" ) {
105
- return checker (op.InputArgumentNames (), send_op-> OutputArgumentNames () );
135
+ return checker (op.InputArgumentNames (), recv_vars );
106
136
}
137
+
107
138
return false ;
108
139
}
109
140
@@ -132,8 +163,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
132
163
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
133
164
places_.size ());
134
165
135
- // Find "send" op first for split is in front of send.
136
- OpDesc *send_op = GetSendOpDesc (program);
166
+ // find send/recv vars so that we can place the distributed training
167
+ // related op on place 0
168
+ auto send_vars = FindDistTrainSendVars (program);
169
+ auto recv_vars = FindDistTrainRecvVars (program);
137
170
138
171
size_t cur_device_id = 0 ;
139
172
std::vector<std::unordered_set<std::string>> var_name_on_devices;
@@ -147,8 +180,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
147
180
// append rpc op if program is distributed trainer main program.
148
181
// always use the first device
149
182
CreateRPCOp (&result, *op);
150
- } else if (IsDistTrainOp (*op, send_op)) {
151
- CreateComputationalOps (&result, *op, 1 );
183
+ } else if (IsDistTrainOp (*op, send_vars, recv_vars)) {
184
+ // CreateComputationalOps(&result, *op, 1);
185
+ CreateComputationalOp (&result, *op, 0 );
152
186
} else if (IsScaleLossOp (*op)) {
153
187
// user can customize loss@grad if not use_default_grad_scale_
154
188
if (strategy_.gradient_scale_ !=
@@ -213,9 +247,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
213
247
AddOutputToLeafOps (&result);
214
248
215
249
if (VLOG_IS_ON (10 )) {
216
- std::string filename = " /tmp/graph " ;
217
- std::ofstream fout (filename );
218
- PrintGraphviz (*graph, fout );
250
+ std::ostringstream sout ;
251
+ PrintGraphviz (*graph, sout );
252
+ VLOG ( 10 ) << sout. str ( );
219
253
}
220
254
221
255
return std::unique_ptr<SSAGraph>(graph);
@@ -274,6 +308,7 @@ OpDesc *MultiDevSSAGraphBuilder::GetSendOpDesc(
274
308
}
275
309
return nullptr ;
276
310
}
311
+
277
312
void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp (
278
313
SSAGraph *result, const std::string &og) const {
279
314
#ifdef PADDLE_WITH_CUDA
@@ -396,14 +431,14 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
396
431
return var;
397
432
}
398
433
399
- void MultiDevSSAGraphBuilder::ConnectOp (SSAGraph *result,
400
- std::string op_name ) const {
434
+ void MultiDevSSAGraphBuilder::ConnectOp (SSAGraph *result, OpHandleBase *op,
435
+ const std::string &prev_op_name ) const {
401
436
for (auto &prev_op : result->ops_ ) {
402
- if (prev_op->Name () == op_name ) {
437
+ if (prev_op->Name () == prev_op_name ) {
403
438
auto *dep_var = new DummyVarHandle ();
404
439
prev_op->AddOutput (dep_var);
405
440
result->dep_vars_ .emplace (dep_var);
406
- result-> ops_ . back (). get () ->AddInput (dep_var);
441
+ op ->AddInput (dep_var);
407
442
}
408
443
}
409
444
}
@@ -412,14 +447,14 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
412
447
const OpDesc &op) const {
413
448
auto &p = places_[0 ];
414
449
auto *s = local_scopes_[0 ];
415
- VLOG (3 ) << " create rpc op: " << op.Type ();
416
450
result->ops_ .emplace_back (new RPCOpHandle (op, s, p, op.Type ()));
451
+
417
452
if (op.Type () == " send_barrier" ) {
418
- ConnectOp (result, " send_vars" );
453
+ ConnectOp (result, result-> ops_ . back (). get (), " send_vars" );
419
454
} else if (op.Type () == " recv" ) {
420
- ConnectOp (result, " send_barrier" );
455
+ ConnectOp (result, result-> ops_ . back (). get (), " send_barrier" );
421
456
} else if (op.Type () == " fetch_barrier" ) {
422
- ConnectOp (result, " recv" );
457
+ ConnectOp (result, result-> ops_ . back (). get (), " recv" );
423
458
} else if (op.Type () == " send" || op.Type () == " send_vars" ) {
424
459
// do nothing
425
460
} else {
@@ -429,7 +464,6 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
429
464
}
430
465
431
466
// FIXME(wuyi): send op always copy from GPU 0
432
- // result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type()));
433
467
// Create inputs for output on original place and no ssa output
434
468
// is created for send op.
435
469
CreateOpHandleIOs (result, op, 0 );
0 commit comments