
Commit f52d78d

update by comment

1 parent 6d752ba commit f52d78d

File tree

2 files changed: +93 -83 lines changed


paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 85 additions & 77 deletions
@@ -57,6 +57,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
   for (auto &p : params) {
     grad_names_.insert(GradVarName(p));
   }
+  balance_vars_.resize(places_.size(), 0);
 }
 
 void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
@@ -140,11 +141,30 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
          checker(op.InputArgumentNames(), recv_vars);
 }
 
+size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
+    const std::vector<std::string> &var_names) const {
+  int64_t numel_sum = 0;
+  for (auto var_name : var_names) {
+    auto var_desc = all_vars_.at(var_name);
+    PADDLE_ENFORCE_NOT_NULL(var_desc);
+    auto dim = framework::make_ddim(var_desc->GetShape());
+    int64_t numel = framework::product(dim);
+    PADDLE_ENFORCE_GT(numel, 0);
+    numel_sum += numel;
+  }
+
+  auto smallest =
+      std::min_element(std::begin(balance_vars_), std::end(balance_vars_));
+  size_t dev_id =
+      static_cast<size_t>(std::distance(std::begin(balance_vars_), smallest));
+  balance_vars_[dev_id] += numel_sum;
+  return dev_id;
+}
+
 std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
     const ProgramDesc &program) const {
-  std::unordered_map<std::string, VarDesc *> all_vars;
   for (auto *var : program.Block(0).AllVars()) {
-    all_vars[var->Name()] = var;
+    all_vars_.emplace(var->Name(), var);
   }
 
   auto graph = new SSAGraph();
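The new GetAppropriateDeviceID() is a greedy balancer: it sums the element counts of the named variables, picks the device whose running total is currently smallest, and charges that total with the new work. Below is a minimal standalone sketch of the same scheme, assuming sizes are passed in directly rather than derived from VarDesc shapes; PickDevice is a hypothetical name, not part of this commit.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Greedy placement as in GetAppropriateDeviceID: each request goes to the
// device with the smallest accumulated element count, which then grows by
// the request's size.
size_t PickDevice(std::vector<int64_t> *balance, int64_t numel_sum) {
  auto smallest = std::min_element(balance->begin(), balance->end());
  size_t dev_id =
      static_cast<size_t>(std::distance(balance->begin(), smallest));
  (*balance)[dev_id] += numel_sum;
  return dev_id;
}

int main() {
  std::vector<int64_t> balance(4, 0);  // four devices, all idle
  // In the real builder each size is framework::product(make_ddim(shape)).
  for (int64_t numel : {4096, 1024, 1024, 2048, 512}) {
    std::cout << "numel " << numel << " -> dev "
              << PickDevice(&balance, numel) << "\n";
  }
  return 0;
}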
@@ -165,71 +185,15 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
   bcast_var_name_set.resize(places_.size());
 
   size_t cur_device_id = 0;
-  std::vector<int64_t> balance_grads(places_.size(), 0);
-
-  auto get_appropriate_dev = [&](std::vector<std::string> var_names) -> size_t {
-    int64_t numel_all = 0;
-    for (auto var_name : var_names) {
-      auto var_desc = all_vars.at(var_name);
-      PADDLE_ENFORCE_NOT_NULL(var_desc);
-      auto dim = framework::make_ddim(var_desc->GetShape());
-      int64_t numel = framework::product(dim);
-      PADDLE_ENFORCE_GT(numel, 0);
-      numel_all += numel;
-    }
-
-    auto smallest =
-        std::min_element(std::begin(balance_grads), std::end(balance_grads));
-    size_t dev_id =
-        static_cast<size_t>(std::distance(std::begin(balance_grads), smallest));
-    balance_grads[dev_id] += numel_all;
-    return dev_id;
-  };
-
   bool is_forwarding = true;
 
   for (auto *op : program.Block(0).AllOps()) {
     if (boost::get<int>(
             op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
         static_cast<int>(OpRole::kRPC)) {
-      // append rpc op if program is distributed trainer main program.
-      // always use the first device
-      if (op->Type() == "send_vars") {
-        int op_dev_id = GetVarDeviceID(op->InputArgumentNames()[0]);
-        if (op_dev_id == -1) {
-          op_dev_id = get_appropriate_dev(op->InputArgumentNames());
-          for (auto &varname : op->InputArgumentNames()) {
-            var_name_on_devices_.emplace(varname, op_dev_id);
-          }
-        }
-        CreateRPCOp(&result, *op, op_dev_id);
-      } else if (op->Type() == "recv") {
-        int op_dev_id = get_appropriate_dev(op->OutputArgumentNames());
-        for (auto &varname : op->OutputArgumentNames()) {
-          var_name_on_devices_.emplace(varname, op_dev_id);
-        }
-        CreateRPCOp(&result, *op, op_dev_id);
-      } else {
-        // send_barrier and fetch_barrier op would run on device 0
-        CreateRPCOp(&result, *op, 0);
-      }
+      CreateRPCOp(&result, *op);
     } else if (IsDistTrainOp(*op, send_vars, recv_vars)) {
-      if (op->Type() == "split_byref") {
-        int op_dev_id = get_appropriate_dev(op->OutputArgumentNames());
-        for (auto &varname : op->OutputArgumentNames()) {
-          var_name_on_devices_.emplace(varname, op_dev_id);
-        }
-        CreateDistTrainOp(&result, *op, op_dev_id);
-      } else if (op->Type() == "concat") {
-        int op_dev_id = GetVarDeviceID(op->InputArgumentNames()[0]);
-        PADDLE_ENFORCE(op_dev_id != -1,
-                       "can not find right place to concatenate received var.");
-        CreateDistTrainOp(&result, *op, op_dev_id);
-      } else {
-        PADDLE_ENFORCE(
-            "the distribute training related op should be in [split_byref, "
-            "concat].");
-      }
+      CreateDistTrainOp(&result, *op);
     } else if (IsScaleLossOp(*op)) {
       // user can customize loss@grad if not use_default_grad_scale_
       if (strategy_.gradient_scale_ !=
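This hunk removes the inline placement logic from Build(): the local get_appropriate_dev lambda becomes the GetAppropriateDeviceID() member, and the per-op-type branches for send_vars, recv, split_byref, concat, and the barrier ops move into CreateRPCOp() and CreateDistTrainOp(), which now take only the op and choose a device themselves (see the reworked definitions later in the diff).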
@@ -267,13 +231,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
 
       switch (strategy_.reduce_) {
         case BuildStrategy::ReduceStrategy::kReduce:
-          cur_device_id = get_appropriate_dev({g_name});
+          cur_device_id = GetAppropriateDeviceID({g_name});
           CreateReduceOp(&result, g_name, cur_device_id);
           var_name_on_devices_.emplace(g_name, cur_device_id);
           bcast_var_name_set[cur_device_id].emplace(p_name);
           break;
         case BuildStrategy::ReduceStrategy::kAllReduce:
-          if (IsSparseGradient(all_vars, g_name)) {
+          if (IsSparseGradient(g_name)) {
             CreateReduceOp(&result, g_name, 0);
             CreateBroadcastOp(&result, g_name, 0);
           } else {
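The two call sites here simply switch over to the refactored helpers: GetAppropriateDeviceID() replaces the removed lambda for kReduce placement, and IsSparseGradient() no longer needs the VarDesc map passed in.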
@@ -310,11 +274,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
   return std::unique_ptr<SSAGraph>(graph);
 }
 
-bool MultiDevSSAGraphBuilder::IsSparseGradient(
-    const std::unordered_map<std::string, VarDesc *> &all_vars,
-    const std::string &og) const {
-  PADDLE_ENFORCE(all_vars.count(og) != 0);
-  if (all_vars.at(og)->GetType() == proto::VarType::SELECTED_ROWS) {
+bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const {
+  PADDLE_ENFORCE(all_vars_.count(og) != 0);
+  if (all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS) {
     return true;
   }
   return false;
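IsSparseGradient() now reads variable types from the new all_vars_ member populated at the top of Build(), so the map no longer threads through its signature.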
@@ -498,18 +460,66 @@ void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
 }
 
 void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
-                                                const OpDesc &op,
-                                                int place_id) const {
-  CreateComputationalOp(result, op, place_id);
+                                                const OpDesc &op) const {
+  int op_dev_id = -1;
+  if (op.Type() == "split_byref") {
+    op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
+    if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
+      op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames());
+      for (auto &varname : op.InputArgumentNames()) {
+        var_name_on_devices_.emplace(varname, op_dev_id);
+      }
+    }
+    for (auto &varname : op.OutputArgumentNames()) {
+      var_name_on_devices_.emplace(varname, op_dev_id);
+    }
+  } else if (op.Type() == "concat") {
+    op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
+  } else {
+    PADDLE_ENFORCE(
+        "the distribute training related op should be in [split_byref, "
+        "concat].");
+  }
+
+  PADDLE_ENFORCE(op_dev_id != -1,
+                 "can not find right place for distributed op: %s", op.Type());
+
+  CreateComputationalOp(result, op, op_dev_id);
   if (op.Type() == "concat") {
     ConnectOp(result, result->ops_.back().get(), "fetch_barrier");
   }
 }
 
-void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, const OpDesc &op,
-                                          int device_id) const {
-  result->ops_.emplace_back(new RPCOpHandle(op, local_scopes_[device_id],
-                                            op.Type(), places_[device_id]));
+void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
+                                          const OpDesc &op) const {
+  int op_dev_id = -1;
+  if (op.Type() == "send") {
+    op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
+    // the variable name which contains .block means it was splited by
+    // split_byref op
+    // so that we can balance the variable blocks to all the pserver instances.
+    if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce &&
+        op.InputArgumentNames()[0].find(".block") == std::string::npos) {
+      op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames());
+      for (auto &varname : op.InputArgumentNames()) {
+        var_name_on_devices_.emplace(varname, op_dev_id);
+      }
+    }
+  } else if (op.Type() == "recv") {
+    op_dev_id = GetAppropriateDeviceID(op.OutputArgumentNames());
+    for (auto &varname : op.OutputArgumentNames()) {
+      var_name_on_devices_.emplace(varname, op_dev_id);
+    }
+  } else {
+    // send_barrier and fetch_barrier op can be scheduled on device 0
+    op_dev_id = 0;
+  }
+
+  PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s",
+                 op.Type());
+
+  result->ops_.emplace_back(new RPCOpHandle(op, local_scopes_[op_dev_id],
+                                            op.Type(), places_[op_dev_id]));
 
   if (op.Type() == "send_barrier") {
     ConnectOp(result, result->ops_.back().get(), "send");
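CreateRPCOp() keys its balancing decision off a naming convention: an input whose name contains ".block" was produced by split_byref and keeps the device already recorded for it, while an unsplit variable under kAllReduce is rebalanced with GetAppropriateDeviceID(). A small sketch of just that string test; the sample variable names are made up for illustration.

#include <iostream>
#include <string>

// Mirrors the check in CreateRPCOp: ".block" in a name marks a variable
// block produced by split_byref, which should keep its assigned device.
bool IsSplitBlock(const std::string &var_name) {
  return var_name.find(".block") != std::string::npos;
}

int main() {
  for (const std::string &name : {"w.block0", "w.block1", "fc_0.b_0"}) {
    std::cout << name
              << (IsSplitBlock(name) ? ": keep assigned device\n"
                                     : ": rebalance across devices\n");
  }
  return 0;
}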
@@ -525,9 +535,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, const OpDesc &op,
                    "send, send_barrier. recv, fetch_barrier]");
   }
 
-  // TODO(Yancey1989): schedule rpc op on different place may
-  // increate throughput
-  CreateOpHandleIOs(result, op, device_id);
+  CreateOpHandleIOs(result, op, op_dev_id);
 }
 
 bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const {
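The TODO(Yancey1989) about scheduling RPC ops on different places is dropped because the commit implements it: CreateRPCOp() now spreads send and recv ops across devices via GetAppropriateDeviceID().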

paddle/fluid/framework/details/multi_devices_graph_builder.h

Lines changed: 8 additions & 6 deletions
@@ -65,9 +65,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
 
   bool IsScaleLossOp(const OpDesc &op) const;
 
-  void CreateRPCOp(SSAGraph *result, const OpDesc &op, int place_id) const;
-  void CreateDistTrainOp(SSAGraph *result, const OpDesc &op,
-                         int place_id) const;
+  void CreateRPCOp(SSAGraph *result, const OpDesc &op) const;
+  void CreateDistTrainOp(SSAGraph *result, const OpDesc &op) const;
 
   /**
    * Is this operator as the end-point operator before/after send operator.
@@ -105,13 +104,16 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
   void CreateBroadcastOp(SSAGraph *result, const std::string &p_name,
                          size_t src_dev_id) const;
 
-  bool IsSparseGradient(
-      const std::unordered_map<std::string, VarDesc *> &all_vars,
-      const std::string &og) const;
+  bool IsSparseGradient(const std::string &og) const;
+
+  size_t GetAppropriateDeviceID(
+      const std::vector<std::string> &var_names) const;
 
  private:
   BuildStrategy strategy_;
+  mutable std::unordered_map<std::string, VarDesc *> all_vars_;
   mutable std::unordered_map<std::string, int> var_name_on_devices_;
+  mutable std::vector<int64_t> balance_vars_;
 
   void SetCommunicationContext(OpHandleBase *op_handle,
                                const platform::Place &p) const;
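Build() and its helpers are all const member functions, so the state they accumulate across calls (the VarDesc cache, the variable-to-device map, and the per-device load counters) is declared mutable. A compilable sketch of that pattern; Builder and Place() are toy stand-ins, not the real MultiDevSSAGraphBuilder API.

#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

class Builder {
 public:
  // const, like MultiDevSSAGraphBuilder::Build(), yet it records state.
  size_t Place(const std::string &name, int64_t numel) const {
    size_t dev = 0;  // placement decision elided
    var_name_on_devices_.emplace(name, static_cast<int>(dev));
    balance_vars_[dev] += numel;  // legal in a const method: member is mutable
    return dev;
  }

 private:
  mutable std::unordered_map<std::string, int> var_name_on_devices_;
  mutable std::vector<int64_t> balance_vars_{0, 0, 0, 0};
};

int main() {
  const Builder b;
  b.Place("fc_0.w_0@GRAD", 4096);  // compiles because the caches are mutable
  return 0;
}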
