// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
+#include <fstream>
#include <utility>
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
+#include "paddle/fluid/framework/details/rpc_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
-#include "paddle/fluid/framework/details/send_op_handle.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/scope.h"

#include <string>
#include <vector>

+DEFINE_string(ssa_graph_path, "/tmp/ssa_graph.dot",
+              "the SSA graph path, only printed when GLOG_v=10 is set, "
+              "default: /tmp/ssa_graph.dot");
+
namespace paddle {
namespace framework {
namespace details {
@@ -79,32 +84,66 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
  }
}

-bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
-                                            OpDesc *send_op) const {
-  if (send_op == nullptr) {
+std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
+    const ProgramDesc &program) const {
+  std::vector<std::string> send_vars;
+  // since parameters are all in block 0,
+  // it's enough to only scan send ops in block 0
+  for (auto *op : program.Block(0).AllOps()) {
+    // TODO(Yancey1989): use a graceful method to find send op,
+    // instead of the hard-coded string
+    if (op->Type() == "send_vars") {
+      auto op_vars = op->InputArgumentNames();
+      send_vars.reserve(send_vars.size() +
+                        std::distance(op_vars.begin(), op_vars.end()));
+      send_vars.insert(send_vars.end(), op_vars.begin(), op_vars.end());
+    }
+  }
+  return send_vars;
+}
+
+std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
+    const ProgramDesc &program) const {
+  std::vector<std::string> recv_vars;
+  for (auto *op : program.Block(0).AllOps()) {
+    // TODO(Yancey1989): use a graceful method to find recv op,
+    // instead of the hard-coded string
+    if (op->Type() == "recv") {
+      auto op_vars = op->OutputArgumentNames();
+      recv_vars.reserve(recv_vars.size() +
+                        std::distance(op_vars.begin(), op_vars.end()));
+      recv_vars.insert(recv_vars.end(), op_vars.begin(), op_vars.end());
+    }
+  }
+  return recv_vars;
+}
+
+bool MultiDevSSAGraphBuilder::IsDistTrainOp(
+    const OpDesc &op, const std::vector<std::string> &send_vars,
+    const std::vector<std::string> &recv_vars) const {
+  if (send_vars.size() == 0 || recv_vars.size() == 0) {
    return false;
  }

  /**
   * Check any of opvars contains `.block` and in sendvars
   */
  auto checker = [](const std::vector<std::string> &opvars,
-                    const std::vector<std::string> &sendvars) -> bool {
+                    const std::vector<std::string> &rpc_vars) -> bool {
    for (auto &var : opvars) {
+      // a variable name with the suffix `.block` means it's a variable
+      // split by the DistributeTranspiler
+      // [python/paddle/fluid/transpiler/distribute_transpiler.py]
      if (var.find(".block") != std::string::npos &&
-          std::find(sendvars.begin(), sendvars.end(), var) != sendvars.end()) {
+          std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
        return true;
      }
    }
    return false;
  };

-  if (op.Type() == "split" || op.Type() == "split_byref") {
-    return checker(op.OutputArgumentNames(), send_op->InputArgumentNames());
-  } else if (op.Type() == "concat") {
-    return checker(op.InputArgumentNames(), send_op->OutputArgumentNames());
-  }
-  return false;
+  return checker(op.OutputArgumentNames(), send_vars) ||
+         checker(op.InputArgumentNames(), recv_vars);
}
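For intuition, here is a minimal standalone sketch (hypothetical helper and variable names, not Paddle code) of what the checker above accepts: a variable only counts when it carries the `.block` suffix that the DistributeTranspiler gives to split slices and also appears in the send/recv variable list.

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// MatchesRpcVar mirrors the lambda's logic; the sample names below are made up.
bool MatchesRpcVar(const std::vector<std::string> &opvars,
                   const std::vector<std::string> &rpc_vars) {
  for (const auto &var : opvars) {
    if (var.find(".block") != std::string::npos &&
        std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
      return true;  // a split slice that is also sent/received over RPC
    }
  }
  return false;
}

int main() {
  std::vector<std::string> send_vars = {"fc_w.block0", "fc_w.block1"};
  std::cout << MatchesRpcVar({"fc_w.block0"}, send_vars) << "\n";  // 1
  std::cout << MatchesRpcVar({"fc_w"}, send_vars) << "\n";         // 0: no ".block" suffix
  return 0;
}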

std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
@@ -123,8 +162,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
      std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
      places_.size());

-  // Find "send" op first for split is in front of send.
-  OpDesc *send_op = GetSendOpDesc(program);
+  // find send/recv vars so that we can place the distributed training
+  // related ops on place 0
+  auto send_vars = FindDistTrainSendVars(program);
+  auto recv_vars = FindDistTrainRecvVars(program);

  size_t cur_device_id = 0;
  std::vector<std::unordered_set<std::string>> var_name_on_devices;
@@ -134,12 +175,14 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(

  bool is_forwarding = true;
  for (auto *op : program.Block(0).AllOps()) {
-    if (op->Type() == "send") {
-      // append send op if program is distributed trainer main program.
+    if (boost::get<int>(
+            op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
+        static_cast<int>(OpRole::kRPC)) {
+      // append an rpc op if the program is a distributed trainer main program.
      // always use the first device
-      CreateSendOp(&result, *op);
-    } else if (IsDistTrainOp(*op, send_op)) {
-      CreateComputationalOps(&result, *op, 1);
+      CreateRPCOp(&result, *op);
+    } else if (IsDistTrainOp(*op, send_vars, recv_vars)) {
+      CreateDistTrainOp(&result, *op);
    } else if (IsScaleLossOp(*op)) {
      // user can customize loss@grad if not use_default_grad_scale_
      if (strategy_.gradient_scale_ !=
@@ -218,9 +261,8 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
  AddOutputToLeafOps(&result);

  if (VLOG_IS_ON(10)) {
-    std::ostringstream sout;
-    PrintGraphviz(*graph, sout);
-    VLOG(10) << sout.str();
+    std::ofstream fout(FLAGS_ssa_graph_path);
+    PrintGraphviz(*graph, fout);
  }

  return std::unique_ptr<SSAGraph>(graph);
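For context, here is a minimal, self-contained sketch of the dump pattern adopted in the hunk above: a gflags string flag names the output path, and the write is gated by glog verbosity. The flag and function names in the sketch are illustrative; the builder itself uses FLAGS_ssa_graph_path and PrintGraphviz.

#include <fstream>
#include <string>

#include <gflags/gflags.h>
#include <glog/logging.h>

// Illustrative flag; the builder above uses FLAGS_ssa_graph_path instead.
DEFINE_string(graph_dump_path, "/tmp/graph.dot", "where to write the graph dump");

void MaybeDumpGraph(const std::string &dot_text) {
  if (VLOG_IS_ON(10)) {  // enabled by running the binary with GLOG_v=10
    std::ofstream fout(FLAGS_graph_dump_path);
    fout << dot_text;    // Graphviz text; render with `dot -Tpng` if desired
  }
}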
@@ -270,15 +312,6 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(SSAGraph *result,
  CreateOpHandleIOs(result, op, dev_id);
}

-OpDesc *MultiDevSSAGraphBuilder::GetSendOpDesc(
-    const ProgramDesc &program) const {
-  for (auto *op : program.Block(0).AllOps()) {
-    if (op->Type() == "send") {
-      return op;
-    }
-  }
-  return nullptr;
-}

void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp(
    SSAGraph *result, const std::string &og) const {
#ifdef PADDLE_WITH_CUDA
@@ -401,14 +434,48 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
  return var;
}

-void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result,
-                                           const OpDesc &op) const {
+void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
+                                        const std::string &prev_op_name) const {
+  for (auto &prev_op : result->ops_) {
+    if (prev_op->Name() == prev_op_name) {
+      auto *dep_var = new DummyVarHandle();
+      prev_op->AddOutput(dep_var);
+      result->dep_vars_.emplace(dep_var);
+      op->AddInput(dep_var);
+    }
+  }
+}
+
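The idea behind ConnectOp, sketched below with simplified, hypothetical Op/Var structs (the real OpHandleBase and DummyVarHandle are Paddle classes): the earlier op produces a dummy variable that the later op consumes, so any scheduler that honors data dependencies is forced to run them in order.

#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-ins for OpHandleBase / DummyVarHandle, for illustration only.
struct Var {};  // carries no data, it only expresses an ordering dependency

struct Op {
  std::string name;
  std::vector<Var *> inputs, outputs;
};

// Make `later` depend on every op named `prev_name` by threading a dummy Var
// from the earlier op's outputs into the later op's inputs.
void Connect(std::vector<std::unique_ptr<Op>> &ops,
             std::vector<std::unique_ptr<Var>> &dep_vars, Op *later,
             const std::string &prev_name) {
  for (auto &prev : ops) {
    if (prev->name == prev_name) {
      dep_vars.push_back(std::make_unique<Var>());
      Var *dep = dep_vars.back().get();
      prev->outputs.push_back(dep);  // produced by the earlier op
      later->inputs.push_back(dep);  // consumed by the later op
    }
  }
}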
+void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
+                                                const OpDesc &op) const {
+  CreateComputationalOp(result, op, 0);
+  if (op.Type() == "concat") {
+    ConnectOp(result, result->ops_.back().get(), "fetch_barrier");
+  }
+}
+
+void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
+                                          const OpDesc &op) const {
  auto &p = places_[0];
  auto *s = local_scopes_[0];
-  // FIXME(wuyi): send op always copy from GPU 0
-  result->ops_.emplace_back(new SendOpHandle(op, s, p));
-  // Create inputs for output on original place and no ssa output
-  // is created for send op.
+  result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type()));
+
+  if (op.Type() == "send_barrier") {
+    ConnectOp(result, result->ops_.back().get(), "send_vars");
+  } else if (op.Type() == "recv") {
+    ConnectOp(result, result->ops_.back().get(), "send_barrier");
+  } else if (op.Type() == "fetch_barrier") {
+    ConnectOp(result, result->ops_.back().get(), "recv");
+  } else if (op.Type() == "send_vars") {
+    // do nothing
+  } else {
+    PADDLE_THROW(
+        "rpc op should be in ["
+        "send_vars, send_barrier, recv, fetch_barrier]");
+  }
+
477
+ // TODO(Yancey1989): schedule rpc op on different place may
478
+ // increate throughput
412
479
CreateOpHandleIOs (result, op, 0 );
413
480
}
414
481
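The ConnectOp calls above impose a fixed ordering on the RPC phases: send_vars, then send_barrier, then recv, then fetch_barrier, with a dist-train concat waiting on fetch_barrier. A compact, hypothetical restatement of that wait-on chain (not part of the builder):

#include <map>
#include <string>

// Hypothetical summary of the dependencies wired by CreateRPCOp/CreateDistTrainOp.
std::map<std::string, std::string> RpcWaitsOn() {
  return {
      {"send_barrier", "send_vars"},  // barrier runs after the sends
      {"recv", "send_barrier"},       // receive only after the send barrier
      {"fetch_barrier", "recv"},      // fetch barrier runs after the receives
      {"concat", "fetch_barrier"},    // dist-train concat waits for fetched slices
      // "send_vars" waits on nothing
  };
}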