Commit 9cc1eb4

Author: Yancey
Merge pull request #11221 from Yancey1989/overlap_memcpy_with_dist

overlap rpc op memcpy in distributed training

2 parents: bfe5dc6 + 7e6518e
File tree

5 files changed (+127 -62 lines)

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 95 additions & 50 deletions
@@ -57,6 +57,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
   for (auto &p : params) {
     grad_names_.insert(GradVarName(p));
   }
+  balance_vars_.resize(places_.size(), 0);
 }
 
 void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
@@ -140,11 +141,30 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
          checker(op.InputArgumentNames(), recv_vars);
 }
 
+size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
+    const std::vector<std::string> &var_names) const {
+  int64_t numel_sum = 0;
+  for (auto var_name : var_names) {
+    auto var_desc = all_vars_.at(var_name);
+    PADDLE_ENFORCE_NOT_NULL(var_desc);
+    auto dim = framework::make_ddim(var_desc->GetShape());
+    int64_t numel = framework::product(dim);
+    PADDLE_ENFORCE_GT(numel, 0);
+    numel_sum += numel;
+  }
+
+  auto smallest =
+      std::min_element(std::begin(balance_vars_), std::end(balance_vars_));
+  size_t dev_id =
+      static_cast<size_t>(std::distance(std::begin(balance_vars_), smallest));
+  balance_vars_[dev_id] += numel_sum;
+  return dev_id;
+}
+
 std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
     const ProgramDesc &program) const {
-  std::unordered_map<std::string, VarDesc *> all_vars;
   for (auto *var : program.Block(0).AllVars()) {
-    all_vars[var->Name()] = var;
+    all_vars_.emplace(var->Name(), var);
   }
 
   auto graph = new SSAGraph();
@@ -161,35 +181,16 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
   auto send_vars = FindDistTrainSendVars(program);
   auto recv_vars = FindDistTrainRecvVars(program);
 
-  std::vector<std::unordered_set<std::string>> var_name_on_devices;
   std::vector<std::unordered_set<std::string>> bcast_var_name_set;
-  var_name_on_devices.resize(places_.size());
   bcast_var_name_set.resize(places_.size());
 
   size_t cur_device_id = 0;
-  std::vector<int64_t> balance_grads(places_.size(), 0);
-
-  auto get_appropriate_dev = [&](std::string &g_name) -> size_t {
-    auto var_desc = all_vars.at(g_name);
-    PADDLE_ENFORCE_NOT_NULL(var_desc);
-    auto dim = framework::make_ddim(var_desc->GetShape());
-    int64_t numel = framework::product(dim);
-    PADDLE_ENFORCE_GE(numel, 0);
-    auto smallest =
-        std::min_element(std::begin(balance_grads), std::end(balance_grads));
-    size_t dev_id =
-        static_cast<size_t>(std::distance(std::begin(balance_grads), smallest));
-    balance_grads[dev_id] += numel;
-    return dev_id;
-  };
-
   bool is_forwarding = true;
+
   for (auto *op : program.Block(0).AllOps()) {
     if (boost::get<int>(
             op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
         static_cast<int>(OpRole::kRPC)) {
-      // append rpc op if program is distributed trainer main program.
-      // always use the first device
       CreateRPCOp(&result, *op);
     } else if (IsDistTrainOp(*op, send_vars, recv_vars)) {
       CreateDistTrainOp(&result, *op);
@@ -201,13 +202,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       }
       is_forwarding = false;
     } else {
-      int op_dev_id = GetOpDeviceID(var_name_on_devices, *op);
+      int op_dev_id = GetOpDeviceID(*op);
       if (op_dev_id == -1) {  // var on all device
         CreateComputationalOps(&result, *op, places_.size());
       } else {
         CreateComputationalOp(&result, *op, op_dev_id);
         for (auto &var_name : op->OutputArgumentNames()) {
-          var_name_on_devices[op_dev_id].emplace(var_name);
+          var_name_on_devices_.emplace(var_name, op_dev_id);
         }
       }
       if (!is_forwarding && places_.size() > 1) {
@@ -230,13 +231,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
 
         switch (strategy_.reduce_) {
           case BuildStrategy::ReduceStrategy::kReduce:
-            cur_device_id = get_appropriate_dev(g_name);
+            cur_device_id = GetAppropriateDeviceID({g_name});
             CreateReduceOp(&result, g_name, cur_device_id);
-            var_name_on_devices[cur_device_id].emplace(g_name);
+            var_name_on_devices_.emplace(g_name, cur_device_id);
             bcast_var_name_set[cur_device_id].emplace(p_name);
             break;
           case BuildStrategy::ReduceStrategy::kAllReduce:
-            if (IsSparseGradient(all_vars, g_name)) {
+            if (IsSparseGradient(g_name)) {
               CreateReduceOp(&result, g_name, 0);
               CreateBroadcastOp(&result, g_name, 0);
             } else {
@@ -273,11 +274,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
   return std::unique_ptr<SSAGraph>(graph);
 }
 
-bool MultiDevSSAGraphBuilder::IsSparseGradient(
-    const std::unordered_map<std::string, VarDesc *> &all_vars,
-    const std::string &og) const {
-  PADDLE_ENFORCE(all_vars.count(og) != 0);
-  if (all_vars.at(og)->GetType() == proto::VarType::SELECTED_ROWS) {
+bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const {
+  PADDLE_ENFORCE(all_vars_.count(og) != 0);
+  if (all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS) {
     return true;
   }
   return false;
@@ -363,24 +362,23 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
   return is_pg_once;
 }
 
-int MultiDevSSAGraphBuilder::GetOpDeviceID(
-    const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
-    const OpDesc &op) const {
+int MultiDevSSAGraphBuilder::GetOpDeviceID(const OpDesc &op) const {
   if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
     return -1;
   }
 
-  int var_dev_id = -1;
-  for (auto &var_name : op.InputArgumentNames()) {
-    if (var_dev_id != -1) break;
-    for (size_t i = 0; i < var_name_on_devices.size(); ++i) {
-      if (var_name_on_devices[i].count(var_name)) {
-        var_dev_id = static_cast<int>(i);
-        break;
-      }
+  for (auto &varname : op.InputArgumentNames()) {
+    int dev_id = GetVarDeviceID(varname);
+    if (dev_id != -1) {
+      return dev_id;
     }
   }
-  return var_dev_id;
+  return -1;
+}
+
+int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const {
+  auto got = var_name_on_devices_.find(varname);
+  return got == var_name_on_devices_.end() ? -1 : got->second;
 }
 
 void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
@@ -463,16 +461,65 @@ void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
 
 void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
                                                 const OpDesc &op) const {
-  CreateComputationalOp(result, op, 0);
+  int op_dev_id = -1;
+  if (op.Type() == "split_byref") {
+    op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
+    if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
+      op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames());
+      for (auto &varname : op.InputArgumentNames()) {
+        var_name_on_devices_.emplace(varname, op_dev_id);
+      }
+    }
+    for (auto &varname : op.OutputArgumentNames()) {
+      var_name_on_devices_.emplace(varname, op_dev_id);
+    }
+  } else if (op.Type() == "concat") {
+    op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
+  } else {
+    PADDLE_ENFORCE(
+        "the distributed training related op should be in [split_byref, "
+        "concat].");
+  }
+
+  PADDLE_ENFORCE(op_dev_id != -1,
+                 "can not find right place for distributed op: %s", op.Type());
+
+  CreateComputationalOp(result, op, op_dev_id);
   if (op.Type() == "concat") {
     ConnectOp(result, result->ops_.back().get(), "fetch_barrier");
   }
 }
 
 void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
                                           const OpDesc &op) const {
-  result->ops_.emplace_back(
-      new RPCOpHandle(op, local_scopes_[0], op.Type(), places_[0]));
+  int op_dev_id = -1;
+  if (op.Type() == "send") {
+    op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
+    // a variable name that contains ".block" was split by the split_byref
+    // op, so the variable blocks can be balanced across all the pserver
+    // instances.
+    if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce &&
+        op.InputArgumentNames()[0].find(".block") == std::string::npos) {
+      op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames());
+      for (auto &varname : op.InputArgumentNames()) {
+        var_name_on_devices_.emplace(varname, op_dev_id);
+      }
+    }
+  } else if (op.Type() == "recv") {
+    op_dev_id = GetAppropriateDeviceID(op.OutputArgumentNames());
+    for (auto &varname : op.OutputArgumentNames()) {
+      var_name_on_devices_.emplace(varname, op_dev_id);
+    }
+  } else {
+    // send_barrier and fetch_barrier ops can be scheduled on device 0
+    op_dev_id = 0;
+  }
+
+  PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s",
+                 op.Type());
+
+  result->ops_.emplace_back(new RPCOpHandle(op, local_scopes_[op_dev_id],
+                                            op.Type(), places_[op_dev_id]));
 
   if (op.Type() == "send_barrier") {
     ConnectOp(result, result->ops_.back().get(), "send");
@@ -488,9 +535,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
                      "send, send_barrier. recv, fetch_barrier]");
   }
 
-  // TODO(Yancey1989): schedule rpc op on different place may
-  // increate throughput
-  CreateOpHandleIOs(result, op, 0);
+  CreateOpHandleIOs(result, op, op_dev_id);
 }
 
 bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const {
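
The new GetAppropriateDeviceID promotes the old get_appropriate_dev lambda into a reusable method: each variable (or group of variables) is placed on whichever device currently holds the fewest total elements, so reduce ops and send/recv inputs spread evenly instead of piling onto device 0. A minimal standalone sketch of the same greedy strategy follows; PickDevice and the driver loop are illustrative names, not Paddle APIs.

// Greedy load balancing as in GetAppropriateDeviceID: track per-device
// element counts and always place the next variable on the least-loaded
// device. Sketch only, not the Paddle implementation.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

size_t PickDevice(std::vector<int64_t> *balance, int64_t numel_sum) {
  auto smallest = std::min_element(balance->begin(), balance->end());
  size_t dev_id =
      static_cast<size_t>(std::distance(balance->begin(), smallest));
  (*balance)[dev_id] += numel_sum;  // charge the chosen device
  return dev_id;
}

int main() {
  std::vector<int64_t> balance(4, 0);  // four devices, all initially empty
  for (int64_t numel : {4096, 1024, 4096, 512, 2048}) {  // tensor sizes
    std::cout << "numel=" << numel << " -> device "
              << PickDevice(&balance, numel) << "\n";
  }
  return 0;
}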

paddle/fluid/framework/details/multi_devices_graph_builder.h

Lines changed: 10 additions & 7 deletions
@@ -47,10 +47,11 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
 #endif
 
   std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
+  int GetVarDeviceID(const std::string &varname) const;
 
  private:
   void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op,
-                         size_t place_id) const;
+                         size_t device_id) const;
 
  private:
   std::string loss_var_name_;
@@ -96,21 +97,23 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
       const std::string &og,
       std::unordered_set<std::string> *og_has_been_broadcast) const;
 
-  int GetOpDeviceID(
-      const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
-      const OpDesc &op) const;
+  int GetOpDeviceID(const OpDesc &op) const;
 
   void InsertAllReduceOp(SSAGraph *result, const std::string &og) const;
 
   void CreateBroadcastOp(SSAGraph *result, const std::string &p_name,
                          size_t src_dev_id) const;
 
-  bool IsSparseGradient(
-      const std::unordered_map<std::string, VarDesc *> &all_vars,
-      const std::string &og) const;
+  bool IsSparseGradient(const std::string &og) const;
+
+  size_t GetAppropriateDeviceID(
+      const std::vector<std::string> &var_names) const;
 
  private:
   BuildStrategy strategy_;
+  mutable std::unordered_map<std::string, VarDesc *> all_vars_;
+  mutable std::unordered_map<std::string, int> var_name_on_devices_;
+  mutable std::vector<int64_t> balance_vars_;
 
   void SetCommunicationContext(OpHandleBase *op_handle,
                                const platform::Place &p) const;
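
The three new members are declared mutable because Build() and the placement helpers are const member functions that must still populate these caches as a side effect of graph construction. A toy illustration of that pattern, under assumed names (Builder, var_dev_), not Paddle code:

#include <string>
#include <unordered_map>

// A const method may write to members declared mutable; this is how a
// const Build() pipeline can record variable placements as it goes.
class Builder {
 public:
  void Build() const { var_dev_.emplace("fc_0.w_0@GRAD", 1); }  // hypothetical var name
  int GetVarDeviceID(const std::string &name) const {
    auto got = var_dev_.find(name);
    return got == var_dev_.end() ? -1 : got->second;
  }

 private:
  mutable std::unordered_map<std::string, int> var_dev_;
};

int main() {
  const Builder b;
  b.Build();                                 // mutates var_dev_ despite const
  return b.GetVarDeviceID("fc_0.w_0@GRAD");  // returns 1
}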

paddle/fluid/framework/details/ssa_graph_builder.h

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ class SSAGraphBuilder {
   SSAGraphBuilder() {}
   virtual ~SSAGraphBuilder() {}
   virtual std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const = 0;
+  virtual int GetVarDeviceID(const std::string &var_name) const { return -1; }
 
   DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
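
The base class gains a virtual GetVarDeviceID with a default of -1 ("unknown"), so callers such as ParallelExecutor can query any builder without knowing whether it tracks placement. A condensed sketch of that dispatch, with illustrative names only:

#include <string>

// Base builders report "unknown" (-1); a placement-aware builder overrides
// the lookup. Callers hold a base reference and need no downcast.
class GraphBuilder {
 public:
  virtual ~GraphBuilder() {}
  virtual int GetVarDeviceID(const std::string &) const { return -1; }
};

class PlacingBuilder : public GraphBuilder {
 public:
  int GetVarDeviceID(const std::string &name) const override {
    return name == "w@GRAD" ? 2 : -1;  // pretend w@GRAD was placed on device 2
  }
};

int main() {
  PlacingBuilder placing;
  const GraphBuilder &b = placing;
  return b.GetVarDeviceID("w@GRAD");  // virtual dispatch yields 2
}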

paddle/fluid/framework/parallel_executor.cc

Lines changed: 18 additions & 5 deletions
@@ -110,7 +110,6 @@ ParallelExecutor::ParallelExecutor(
 
   // Step 3. Convert main_program to SSA form and dependency graph. Also, insert
   // ncclOp
-
   details::SSAGraphBuilderFactory builder_factory(
       member_->places_, loss_var_name, params, member_->local_scopes_,
       build_strategy);
@@ -122,9 +121,10 @@ ParallelExecutor::ParallelExecutor(
 #endif
   }
 
+  builder_ = std::move(builder_factory.Create());
   member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
       exec_strategy, member_->local_scopes_, places,
-      builder_factory.Create()->Build(main_program)));
+      builder_->Build(main_program)));
 
   member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
       exec_strategy, member_->local_scopes_, std::move(var_infos),
@@ -133,10 +133,22 @@ ParallelExecutor::ParallelExecutor(
 
 void ParallelExecutor::BCastParamsToGPUs(
     const std::unordered_set<std::string> &vars) const {
-  auto *main_scope = member_->local_scopes_[0];
+  // for the initialize bcast, all vars would be bcast from device(0);
+  // otherwise bcast from the specified device.
+  bool initialize = builder_.get() == nullptr ? true : false;
 
   for (auto &var : vars) {
-    auto *main_var = main_scope->FindVar(var);
+    int var_dev_id =
+        builder_.get() == nullptr ? -1 : builder_->GetVarDeviceID(var);
+    if (!initialize && var_dev_id == -1) continue;
+
+    framework::Variable *main_var = nullptr;
+    if (initialize) {
+      main_var = member_->local_scopes_[0]->FindVar(var);
+    } else {
+      main_var = member_->local_scopes_[var_dev_id]->FindVar(var);
+    }
+
     if (main_var == nullptr || !main_var->IsType<LoDTensor>()) {
       continue;
     }
@@ -151,7 +163,8 @@ void ParallelExecutor::BCastParamsToGPUs(
     for (size_t i = 0; i < member_->places_.size(); ++i) {
       auto place = member_->places_[i];
       void *buffer;
-      if (i == 0) {
+
+      if ((initialize && i == 0) || (!initialize && i == var_dev_id)) {
         buffer = const_cast<void *>(main_tensor.data<void>());
       } else {
         auto local_scope = member_->local_scopes_[i];
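
With placements recorded by the builder, BCastParamsToGPUs no longer assumes device 0 owns every parameter: on the first (initialize) broadcast the source is device 0, and afterwards it is whichever device the builder assigned the variable to. The predicate below condenses that rule into a standalone sketch; IsBroadcastSource is a hypothetical helper, not the Paddle code.

#include <cstddef>

// True when device i holds the authoritative copy and should serve as the
// broadcast source for this variable.
bool IsBroadcastSource(bool initialize, size_t i, int var_dev_id) {
  return (initialize && i == 0) ||
         (!initialize && static_cast<int>(i) == var_dev_id);
}

int main() {
  // After initialization, a variable placed on device 2 is broadcast from 2.
  return IsBroadcastSource(false, 2, 2) ? 0 : 1;
}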

paddle/fluid/framework/parallel_executor.h

Lines changed: 3 additions & 0 deletions
@@ -19,12 +19,14 @@ limitations under the License. */
 #include <unordered_set>
 #include <vector>
 #include "paddle/fluid/framework/details/execution_strategy.h"
+#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
 #include "paddle/fluid/framework/executor.h"
 #include "paddle/fluid/framework/op_info.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/platform/device_context.h"
+
 namespace paddle {
 namespace framework {
 
@@ -68,6 +70,7 @@ class ParallelExecutor {
 
  private:
   ParallelExecutorPrivate *member_;
+  std::unique_ptr<details::SSAGraphBuilder> builder_;
 };
 
 }  // namespace framework
