Commit d92a75b

Authored by Yancey
Merge pull request #10550 from Yancey1989/overlap_send_op
overlap send ops and backward ops
2 parents 7655e4c + 5d7c58e commit d92a75b

27 files changed: +559 −340 lines

paddle/fluid/framework/details/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
@@ -3,7 +3,7 @@ cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context
 cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
 cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
 cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
-cc_library(send_op_handle SRCS send_op_handle.cc DEPS framework_proto scope place operator op_registry)
+cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry)

 cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
 cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
@@ -26,7 +26,7 @@ endif()
 cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)

 cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
-        scale_loss_grad_op_handle send_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle)
+        scale_loss_grad_op_handle rpc_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle)

 cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
 cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 104 additions & 37 deletions
@@ -12,12 +12,13 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 #include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
+#include <fstream>
 #include <utility>
 #include "paddle/fluid/framework/details/broadcast_op_handle.h"
 #include "paddle/fluid/framework/details/computation_op_handle.h"
 #include "paddle/fluid/framework/details/reduce_op_handle.h"
+#include "paddle/fluid/framework/details/rpc_op_handle.h"
 #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
-#include "paddle/fluid/framework/details/send_op_handle.h"
 #include "paddle/fluid/framework/op_info.h"
 #include "paddle/fluid/framework/scope.h"

@@ -28,6 +29,10 @@
 #include <string>
 #include <vector>

+DEFINE_string(ssa_graph_path, "/tmp/ssa_graph.dot",
+              "the ssa graph path only print with GLOG_v=10,"
+              "default /tmp/graph.dot");
+
 namespace paddle {
 namespace framework {
 namespace details {
@@ -79,32 +84,66 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
   }
 }

-bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
-                                            OpDesc *send_op) const {
-  if (send_op == nullptr) {
+std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
+    const ProgramDesc &program) const {
+  std::vector<std::string> send_vars;
+  // since parameters are all in block 0,
+  // it's enough to only scan send ops in block 0
+  for (auto *op : program.Block(0).AllOps()) {
+    // TODO(Yancey1989): use a graceful method to find send op,
+    // instead of the the hard code string
+    if (op->Type() == "send_vars") {
+      auto op_vars = op->InputArgumentNames();
+      send_vars.reserve(send_vars.size() +
+                        std::distance(op_vars.begin(), op_vars.end()));
+      send_vars.insert(send_vars.end(), op_vars.begin(), op_vars.end());
+    }
+  }
+  return send_vars;
+}
+
+std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
+    const ProgramDesc &program) const {
+  std::vector<std::string> recv_vars;
+  for (auto *op : program.Block(0).AllOps()) {
+    // TODO(Yancey1989): use a graceful method to find recv op,
+    // instead of the hard code string
+    if (op->Type() == "recv") {
+      auto op_vars = op->OutputArgumentNames();
+      recv_vars.reserve(recv_vars.size() +
+                        std::distance(op_vars.begin(), op_vars.end()));
+      recv_vars.insert(recv_vars.end(), op_vars.begin(), op_vars.end());
+    }
+  }
+  return recv_vars;
+}
+
+bool MultiDevSSAGraphBuilder::IsDistTrainOp(
+    const OpDesc &op, const std::vector<std::string> &send_vars,
+    const std::vector<std::string> &recv_vars) const {
+  if (send_vars.size() == 0 || recv_vars.size() == 0) {
     return false;
   }

   /**
    * Check any of opvars contains `.block` and in sendvars
    */
   auto checker = [](const std::vector<std::string> &opvars,
-                    const std::vector<std::string> &sendvars) -> bool {
+                    const std::vector<std::string> &rpc_vars) -> bool {
     for (auto &var : opvars) {
+      // a variable name with the suffix `.block` means it's a splited
+      // variable by (DistributeTranspiler)
+      // [python/paddle/fluid/transpiler/distribute_transpiler.py]
       if (var.find(".block") != std::string::npos &&
-          std::find(sendvars.begin(), sendvars.end(), var) != sendvars.end()) {
+          std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
         return true;
       }
     }
     return false;
   };

-  if (op.Type() == "split" || op.Type() == "split_byref") {
-    return checker(op.OutputArgumentNames(), send_op->InputArgumentNames());
-  } else if (op.Type() == "concat") {
-    return checker(op.InputArgumentNames(), send_op->OutputArgumentNames());
-  }
-  return false;
+  return checker(op.OutputArgumentNames(), send_vars) ||
+         checker(op.InputArgumentNames(), recv_vars);
 }

 std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
@@ -123,8 +162,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
       places_.size());

-  // Find "send" op first for split is in front of send.
-  OpDesc *send_op = GetSendOpDesc(program);
+  // find send/recv vars so that we can place the distributed training
+  // realted op in the place 0
+  auto send_vars = FindDistTrainSendVars(program);
+  auto recv_vars = FindDistTrainRecvVars(program);

   size_t cur_device_id = 0;
   std::vector<std::unordered_set<std::string>> var_name_on_devices;
@@ -134,12 +175,14 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(

   bool is_forwarding = true;
   for (auto *op : program.Block(0).AllOps()) {
-    if (op->Type() == "send") {
-      // append send op if program is distributed trainer main program.
+    if (boost::get<int>(
+            op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
+        static_cast<int>(OpRole::kRPC)) {
+      // append rpc op if program is distributed trainer main program.
       // always use the first device
-      CreateSendOp(&result, *op);
-    } else if (IsDistTrainOp(*op, send_op)) {
-      CreateComputationalOps(&result, *op, 1);
+      CreateRPCOp(&result, *op);
+    } else if (IsDistTrainOp(*op, send_vars, recv_vars)) {
+      CreateDistTrainOp(&result, *op);
     } else if (IsScaleLossOp(*op)) {
       // user can customize loss@grad if not use_default_grad_scale_
       if (strategy_.gradient_scale_ !=
@@ -218,9 +261,8 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
   AddOutputToLeafOps(&result);

   if (VLOG_IS_ON(10)) {
-    std::ostringstream sout;
-    PrintGraphviz(*graph, sout);
-    VLOG(10) << sout.str();
+    std::ofstream fout(FLAGS_ssa_graph_path);
+    PrintGraphviz(*graph, fout);
   }

   return std::unique_ptr<SSAGraph>(graph);
@@ -270,15 +312,6 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(SSAGraph *result,
   CreateOpHandleIOs(result, op, dev_id);
 }

-OpDesc *MultiDevSSAGraphBuilder::GetSendOpDesc(
-    const ProgramDesc &program) const {
-  for (auto *op : program.Block(0).AllOps()) {
-    if (op->Type() == "send") {
-      return op;
-    }
-  }
-  return nullptr;
-}
 void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp(
     SSAGraph *result, const std::string &og) const {
 #ifdef PADDLE_WITH_CUDA
@@ -401,14 +434,48 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
   return var;
 }

-void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result,
-                                           const OpDesc &op) const {
+void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
+                                        const std::string &prev_op_name) const {
+  for (auto &prev_op : result->ops_) {
+    if (prev_op->Name() == prev_op_name) {
+      auto *dep_var = new DummyVarHandle();
+      prev_op->AddOutput(dep_var);
+      result->dep_vars_.emplace(dep_var);
+      op->AddInput(dep_var);
+    }
+  }
+}
+
+void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
+                                                const OpDesc &op) const {
+  CreateComputationalOp(result, op, 0);
+  if (op.Type() == "concat") {
+    ConnectOp(result, result->ops_.back().get(), "fetch_barrier");
+  }
+}
+
+void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
+                                          const OpDesc &op) const {
   auto &p = places_[0];
   auto *s = local_scopes_[0];
-  // FIXME(wuyi): send op always copy from GPU 0
-  result->ops_.emplace_back(new SendOpHandle(op, s, p));
-  // Create inputs for output on original place and no ssa output
-  // is created for send op.
+  result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type()));
+
+  if (op.Type() == "send_barrier") {
+    ConnectOp(result, result->ops_.back().get(), "send_vars");
+  } else if (op.Type() == "recv") {
+    ConnectOp(result, result->ops_.back().get(), "send_barrier");
+  } else if (op.Type() == "fetch_barrier") {
+    ConnectOp(result, result->ops_.back().get(), "recv");
+  } else if (op.Type() == "send_vars") {
+    // do nothing
+  } else {
+    PADDLE_THROW(
+        "rpc op should be in ["
+        "send_vars, send_barrier. recv, fetch_barrier]");
+  }
+
+  // TODO(Yancey1989): schedule rpc op on different place may
+  // increate throughput
   CreateOpHandleIOs(result, op, 0);
 }
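
Note: the ConnectOp/CreateRPCOp pair above is the heart of this PR. The old flow hard-coded a single "send" op; the new flow splits the RPC pipeline into four op handles and orders them (send_vars → send_barrier → recv → fetch_barrier) purely through dummy dependency variables, so each gradient's send_vars can start as soon as that gradient is produced, overlapping with the remaining backward ops. Below is a minimal, self-contained sketch of the dummy-variable trick; Var, Op and Graph are hypothetical stand-ins, not Paddle's real VarHandleBase/OpHandleBase/SSAGraph classes.

#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

struct Var {};  // stand-in for a dummy VarHandle: carries no data, only ordering

struct Op {
  std::string name;
  std::vector<Var *> inputs, outputs;
  explicit Op(std::string n) : name(std::move(n)) {}
};

struct Graph {
  std::vector<std::unique_ptr<Op>> ops;
  std::unordered_set<std::unique_ptr<Var>> dep_vars;  // owns the dummy vars
};

// Mirrors MultiDevSSAGraphBuilder::ConnectOp: every op already in the graph
// whose name matches prev_name gets a fresh dummy output, which `op` takes as
// an input. A scheduler that runs an op only once all of its inputs are ready
// will therefore run all `prev_name` ops before `op`, while ops not on the
// chain stay free to execute concurrently.
void Connect(Graph *g, Op *op, const std::string &prev_name) {
  for (auto &prev : g->ops) {
    if (prev->name == prev_name) {
      auto dep = std::make_unique<Var>();
      prev->outputs.push_back(dep.get());
      op->inputs.push_back(dep.get());
      g->dep_vars.insert(std::move(dep));
    }
  }
}

So Connect(g, fetch_barrier_op, "recv") orders a fetch_barrier after every recv handle without touching any computation op. A smaller change in the same hunk: with GLOG_v=10 the builder now writes the Graphviz dump of the SSA graph to the file named by the new ssa_graph_path flag (default /tmp/ssa_graph.dot) instead of spilling it into the VLOG stream.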

paddle/fluid/framework/details/multi_devices_graph_builder.h

Lines changed: 14 additions & 8 deletions
@@ -64,12 +64,24 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {

   bool IsScaleLossOp(const OpDesc &op) const;

-  void CreateSendOp(SSAGraph *result, const OpDesc &op) const;
+  void CreateRPCOp(SSAGraph *result, const OpDesc &op) const;
+  void CreateDistTrainOp(SSAGraph *result, const OpDesc &op) const;

   /**
    * Is this operator as the end-point operator before/after send operator.
    */
-  bool IsDistTrainOp(const OpDesc &op, OpDesc *send_op) const;
+  bool IsDistTrainOp(const OpDesc &op,
+                     const std::vector<std::string> &send_vars,
+                     const std::vector<std::string> &recv_vars) const;
+
+  std::vector<std::string> FindDistTrainSendVars(
+      const ProgramDesc &program) const;
+
+  std::vector<std::string> FindDistTrainRecvVars(
+      const ProgramDesc &program) const;
+
+  void ConnectOp(SSAGraph *result, OpHandleBase *op,
+                 const std::string &prev_op_name) const;

   void CreateComputationalOps(SSAGraph *result, const OpDesc &op,
                               size_t num_places) const;
@@ -93,12 +105,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
   void CreateBroadcastOp(SSAGraph *result, const std::string &p_name,
                          size_t src_dev_id) const;

-  /**
-   * Get send op in the global block of program.
-   * nullptr if not found.
-   */
-  OpDesc *GetSendOpDesc(const ProgramDesc &program) const;
-
   bool IsSparseGradient(
       const std::unordered_map<std::string, proto::VarType::Type> &var_types,
       const std::string &og) const;

paddle/fluid/framework/details/send_op_handle.cc renamed to paddle/fluid/framework/details/rpc_op_handle.cc

Lines changed: 9 additions & 7 deletions
@@ -12,24 +12,26 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "paddle/fluid/framework/details/send_op_handle.h"
+#include "paddle/fluid/framework/details/rpc_op_handle.h"

 namespace paddle {
 namespace framework {
 namespace details {

-SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc,
-                           const Scope *local_scope,
-                           const platform::Place &place)
+RPCOpHandle::RPCOpHandle(const framework::OpDesc &op_desc,
+                         const Scope *local_scope, const platform::Place &place,
+                         const std::string &name)
     : op_(framework::OpRegistry::CreateOp(op_desc)),
       local_scope_(local_scope),
-      place_(place) {}
+      place_(place),
+      name_(name) {}

-void SendOpHandle::RunImpl() {
+void RPCOpHandle::RunImpl() {
   // TODO(wuyi): need further analysis whether wait VarDummyHandle.
   // Wait input done
   for (auto *in : inputs_) {
     auto &p = static_cast<VarHandle *>(in)->place_;
+    // FIXME(Yancey1989): need a better solution instead of use DebugString()
     if (in->DebugString() == "dummy") {  // HACK
       continue;
     }
@@ -43,7 +45,7 @@ void SendOpHandle::RunImpl() {
   op_->Run(*tmp_scope, place_);
 }

-std::string SendOpHandle::Name() const { return "send"; }
+std::string RPCOpHandle::Name() const { return name_; }
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
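
One detail worth calling out in this rename: the handle now stores the concrete op name and Name() returns it instead of the fixed string "send". ConnectOp in multi_devices_graph_builder.cc locates predecessor RPC handles by comparing Name() against "send_vars", "send_barrier" and "recv", so a single hard-coded name would make those lookups miss. A toy illustration, with a hypothetical RpcHandle type rather than the real OpHandleBase hierarchy:

#include <cassert>
#include <string>
#include <utility>

// One class serves all four RPC op types, so the name must come from the
// OpDesc rather than being hard-coded.
struct RpcHandle {
  explicit RpcHandle(std::string name) : name_(std::move(name)) {}
  std::string Name() const { return name_; }  // was: return "send";
  std::string name_;
};

int main() {
  RpcHandle barrier("send_barrier");         // built with op.Type(), as above
  assert(barrier.Name() == "send_barrier");  // now matchable by ConnectOp
  return 0;
}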

paddle/fluid/framework/details/send_op_handle.h renamed to paddle/fluid/framework/details/rpc_op_handle.h

Lines changed: 4 additions & 3 deletions
@@ -27,9 +27,9 @@ namespace paddle {
 namespace framework {
 namespace details {

-struct SendOpHandle : public OpHandleBase {
-  SendOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope,
-               const platform::Place& place);
+struct RPCOpHandle : public OpHandleBase {
+  RPCOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope,
+              const platform::Place& place, const std::string& name);

   std::string Name() const override;

@@ -44,6 +44,7 @@ struct SendOpHandle : public OpHandleBase {
   std::unique_ptr<OperatorBase> op_;
   const Scope* local_scope_;
   const platform::Place& place_;
+  const std::string name_;
 };

 }  // namespace details

paddle/fluid/framework/op_proto_maker.cc

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
       .InEnum(
           {static_cast<int>(OpRole::kForward),
            static_cast<int>(OpRole::kBackward),
-           static_cast<int>(OpRole::kOptimize),
+           static_cast<int>(OpRole::kOptimize), static_cast<int>(OpRole::kRPC),
            static_cast<int>(OpRole::kLoss) | static_cast<int>(OpRole::kForward),
            static_cast<int>(OpRole::kLoss) |
                static_cast<int>(OpRole::kBackward),

paddle/fluid/framework/op_proto_maker.h

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ enum class OpRole {
   kForward = 0x0000,
   kBackward = 0x0001,
   kOptimize = 0x0002,
+  kRPC = 0x0003,

   kLoss = 0x0100,
   // The default value of op's role. This should be only used for unittests and
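
For context, the role is carried on every op as an integer attribute (named by OpRoleAttrName), and the graph builder above tests it with a plain equality check. A simplified sketch of that check; IsRPCOp is a hypothetical helper for illustration, not part of Paddle's API:

enum class OpRole : int {
  kForward = 0x0000,
  kBackward = 0x0001,
  kOptimize = 0x0002,
  kRPC = 0x0003,
  kLoss = 0x0100,
};

// The builder compares the raw attribute value against a single role, so a
// combined value such as kLoss | kForward (0x0100) does not match kRPC.
inline bool IsRPCOp(int op_role_attr) {
  return op_role_attr == static_cast<int>(OpRole::kRPC);
}

Single roles may also be OR-ed with kLoss (e.g. kLoss | kForward), which is why the InEnum list in op_proto_maker.cc enumerates the combined values explicitly.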

paddle/fluid/inference/analysis/data_flow_graph_tester.cc

Lines changed: 2 additions & 2 deletions
@@ -35,7 +35,7 @@ TEST(DataFlowGraph, BFS) {

   GraphTraits<DataFlowGraph> trait(&dfg);
   auto nodes = trait.nodes();
-  int count = 0;
+  size_t count = 0;
   for (auto it = nodes.begin(); it != nodes.end(); ++it) {
     LOG(INFO) << "visiting " << it->name();
     ++count;
@@ -49,7 +49,7 @@ TEST(DataFlowGraph, DFS) {
   dfg.Build();
   GraphTraits<DataFlowGraph> trait(&dfg);
   auto nodes = trait.nodes_in_DFS();
-  int count = 0;
+  size_t count = 0;
   for (auto it = nodes.begin(); it != nodes.end(); ++it) {
     LOG(INFO) << "visiting " << it->name();
     ++count;

paddle/fluid/operators/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
@@ -200,7 +200,9 @@ if(WITH_DISTRIBUTE)
     op_library(send_vars_op DEPS ${DISTRIBUTE_DEPS})
     set_source_files_properties(send_vars_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
     op_library(send_barrier_op DEPS ${DISTRIBUTE_DEPS})
+    op_library(fetch_barrier_op DEPS ${DISTRIBUTE_DEPS})
     set_source_files_properties(send_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
+    set_source_files_properties(fetch_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
     #set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
     #cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op
     #    listen_and_serv_op sum_op executor SERIAL)
@@ -214,7 +216,7 @@ if(WITH_DISTRIBUTE)
         set(DEPS_OPS ${DEPS_OPS} gen_nccl_id_op)
     endif()
 else()
-    set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op gen_nccl_id_op)
+    set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op fetch_barrier_op gen_nccl_id_op)
 endif()

 op_library(cross_entropy_op DEPS cross_entropy)
