
Commit 6d752ba

committed
use get_appropriate_dev to schedule rpc op
1 parent 4444e79 commit 6d752ba

File tree

4 files changed: +54 −69 lines

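The commit message is terse, so a word on the policy it names: the old schedule_rpc_op lambda assigned RPC ops to devices round-robin; get_appropriate_dev instead picks the device with the smallest accumulated tensor volume and then charges it with the elements of the variables it is about to host. Below is a minimal standalone sketch of that policy, not the Paddle source: the var_numel map stands in for the builder's VarDesc shape lookup, and the function name is illustrative.

#include <algorithm>
#include <cstdint>
#include <iterator>
#include <string>
#include <unordered_map>
#include <vector>

// Least-loaded device pick: choose the device with the smallest running
// element count, then charge it with the variables' total size so that
// later picks stay balanced.
size_t GetAppropriateDev(
    const std::vector<std::string> &var_names,
    const std::unordered_map<std::string, int64_t> &var_numel,
    std::vector<int64_t> *balance_grads) {
  int64_t numel_all = 0;
  for (const auto &name : var_names) {
    numel_all += var_numel.at(name);  // the real builder reads VarDesc shapes
  }
  auto smallest =
      std::min_element(balance_grads->begin(), balance_grads->end());
  size_t dev_id =
      static_cast<size_t>(std::distance(balance_grads->begin(), smallest));
  (*balance_grads)[dev_id] += numel_all;
  return dev_id;
}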

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 48 additions & 51 deletions
@@ -142,7 +142,6 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
 
 std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
     const ProgramDesc &program) const {
-  VLOG(3) << "Building ....";
   std::unordered_map<std::string, VarDesc *> all_vars;
   for (auto *var : program.Block(0).AllVars()) {
     all_vars[var->Name()] = var;
@@ -162,36 +161,32 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
   auto send_vars = FindDistTrainSendVars(program);
   auto recv_vars = FindDistTrainRecvVars(program);
 
-  std::vector<std::unordered_set<std::string>> var_name_on_devices;
   std::vector<std::unordered_set<std::string>> bcast_var_name_set;
-  var_name_on_devices.resize(places_.size());
   bcast_var_name_set.resize(places_.size());
 
   size_t cur_device_id = 0;
   std::vector<int64_t> balance_grads(places_.size(), 0);
 
-  auto get_appropriate_dev = [&](std::string &g_name) -> size_t {
-    auto var_desc = all_vars.at(g_name);
-    PADDLE_ENFORCE_NOT_NULL(var_desc);
-    auto dim = framework::make_ddim(var_desc->GetShape());
-    int64_t numel = framework::product(dim);
-    PADDLE_ENFORCE_GE(numel, 0);
+  auto get_appropriate_dev = [&](std::vector<std::string> var_names) -> size_t {
+    int64_t numel_all = 0;
+    for (auto var_name : var_names) {
+      auto var_desc = all_vars.at(var_name);
+      PADDLE_ENFORCE_NOT_NULL(var_desc);
+      auto dim = framework::make_ddim(var_desc->GetShape());
+      int64_t numel = framework::product(dim);
+      PADDLE_ENFORCE_GT(numel, 0);
+      numel_all += numel;
+    }
+
     auto smallest =
         std::min_element(std::begin(balance_grads), std::end(balance_grads));
     size_t dev_id =
         static_cast<size_t>(std::distance(std::begin(balance_grads), smallest));
-    balance_grads[dev_id] += numel;
+    balance_grads[dev_id] += numel_all;
     return dev_id;
   };
 
   bool is_forwarding = true;
-  int rpc_op_device_id = 0;
-  auto schedule_rpc_op = [&]() -> void {
-    rpc_op_device_id++;
-    if (rpc_op_device_id >= static_cast<int>(places_.size())) {
-      rpc_op_device_id = 0;
-    }
-  };
 
   for (auto *op : program.Block(0).AllOps()) {
     if (boost::get<int>(
@@ -200,37 +195,40 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       // append rpc op if program is distributed trainer main program.
       // always use the first device
       if (op->Type() == "send_vars") {
-        auto got = remote_vars_devices_.find(op->InputArgumentNames()[0]);
-        if (got == remote_vars_devices_.end()) {
-          schedule_rpc_op();
-        } else {
-          rpc_op_device_id = got->second;
+        int op_dev_id = GetVarDeviceID(op->InputArgumentNames()[0]);
+        if (op_dev_id == -1) {
+          op_dev_id = get_appropriate_dev(op->InputArgumentNames());
+          for (auto &varname : op->InputArgumentNames()) {
+            var_name_on_devices_.emplace(varname, op_dev_id);
+          }
         }
-        CreateRPCOp(&result, *op, rpc_op_device_id);
+        CreateRPCOp(&result, *op, op_dev_id);
       } else if (op->Type() == "recv") {
-        schedule_rpc_op();
+        int op_dev_id = get_appropriate_dev(op->OutputArgumentNames());
         for (auto &varname : op->OutputArgumentNames()) {
-          remote_vars_devices_.insert({varname, rpc_op_device_id});
+          var_name_on_devices_.emplace(varname, op_dev_id);
         }
-        CreateRPCOp(&result, *op, rpc_op_device_id);
+        CreateRPCOp(&result, *op, op_dev_id);
       } else {
+        // send_barrier and fetch_barrier op would run on device 0
         CreateRPCOp(&result, *op, 0);
       }
     } else if (IsDistTrainOp(*op, send_vars, recv_vars)) {
       if (op->Type() == "split_byref") {
-        schedule_rpc_op();
+        int op_dev_id = get_appropriate_dev(op->OutputArgumentNames());
        for (auto &varname : op->OutputArgumentNames()) {
-          remote_vars_devices_.insert({varname, rpc_op_device_id});
+          var_name_on_devices_.emplace(varname, op_dev_id);
         }
-        CreateDistTrainOp(&result, *op, rpc_op_device_id);
-      }
-      if (op->Type() == "concat") {
-        auto got = remote_vars_devices_.find(op->InputArgumentNames()[0]);
-        PADDLE_ENFORCE(got != remote_vars_devices_.end(),
+        CreateDistTrainOp(&result, *op, op_dev_id);
+      } else if (op->Type() == "concat") {
+        int op_dev_id = GetVarDeviceID(op->InputArgumentNames()[0]);
+        PADDLE_ENFORCE(op_dev_id != -1,
                        "can not find right place to concatenate received var.");
-        CreateDistTrainOp(&result, *op, got->second);
+        CreateDistTrainOp(&result, *op, op_dev_id);
       } else {
-        CreateDistTrainOp(&result, *op, 0);
+        PADDLE_ENFORCE(
+            "the distribute training related op should be in [split_byref, "
+            "concat].");
       }
     } else if (IsScaleLossOp(*op)) {
       // user can customize loss@grad if not use_default_grad_scale_
@@ -240,13 +238,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       }
       is_forwarding = false;
     } else {
-      int op_dev_id = GetOpDeviceID(var_name_on_devices, *op);
+      int op_dev_id = GetOpDeviceID(*op);
       if (op_dev_id == -1) {  // var on all device
         CreateComputationalOps(&result, *op, places_.size());
       } else {
         CreateComputationalOp(&result, *op, op_dev_id);
         for (auto &var_name : op->OutputArgumentNames()) {
-          var_name_on_devices[op_dev_id].emplace(var_name);
+          var_name_on_devices_.emplace(var_name, op_dev_id);
         }
       }
       if (!is_forwarding && places_.size() > 1) {
@@ -269,9 +267,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
 
         switch (strategy_.reduce_) {
           case BuildStrategy::ReduceStrategy::kReduce:
-            cur_device_id = get_appropriate_dev(g_name);
+            cur_device_id = get_appropriate_dev({g_name});
             CreateReduceOp(&result, g_name, cur_device_id);
-            var_name_on_devices[cur_device_id].emplace(g_name);
+            var_name_on_devices_.emplace(g_name, cur_device_id);
             bcast_var_name_set[cur_device_id].emplace(p_name);
             break;
           case BuildStrategy::ReduceStrategy::kAllReduce:
@@ -402,24 +400,23 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
   return is_pg_once;
 }
 
-int MultiDevSSAGraphBuilder::GetOpDeviceID(
-    const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
-    const OpDesc &op) const {
+int MultiDevSSAGraphBuilder::GetOpDeviceID(const OpDesc &op) const {
   if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
     return -1;
   }
 
-  int var_dev_id = -1;
-  for (auto &var_name : op.InputArgumentNames()) {
-    if (var_dev_id != -1) break;
-    for (size_t i = 0; i < var_name_on_devices.size(); ++i) {
-      if (var_name_on_devices[i].count(var_name)) {
-        var_dev_id = static_cast<int>(i);
-        break;
-      }
+  for (auto &varname : op.InputArgumentNames()) {
+    int dev_id = GetVarDeviceID(varname);
+    if (dev_id != -1) {
+      return dev_id;
     }
   }
-  return var_dev_id;
+  return -1;
+}
+
+int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const {
+  auto got = var_name_on_devices_.find(varname);
+  return got == var_name_on_devices_.end() ? -1 : got->second;
 }
 
 void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
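The schedule above works because every pinned variable is recorded exactly once in var_name_on_devices_, and both later lookups are built on that map. A rough standalone model of the registry and the two lookups, using a hypothetical VarDeviceRegistry class rather than the Paddle source:

#include <string>
#include <unordered_map>
#include <vector>

// Rough model: ops that pin variables record them once; later ops
// (concat, the NCCL broadcast in ParallelExecutor) look the device up.
class VarDeviceRegistry {
 public:
  void Assign(const std::vector<std::string> &names, int dev_id) {
    for (const auto &n : names) var_dev_.emplace(n, dev_id);
  }
  // Mirrors GetVarDeviceID: -1 means the variable is not pinned anywhere.
  int GetVarDeviceID(const std::string &name) const {
    auto got = var_dev_.find(name);
    return got == var_dev_.end() ? -1 : got->second;
  }
  // Mirrors GetOpDeviceID under kReduce: the first pinned input decides.
  int GetOpDeviceID(const std::vector<std::string> &inputs) const {
    for (const auto &n : inputs) {
      int dev_id = GetVarDeviceID(n);
      if (dev_id != -1) return dev_id;
    }
    return -1;
  }

 private:
  std::unordered_map<std::string, int> var_dev_;
};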

paddle/fluid/framework/details/multi_devices_graph_builder.h

Lines changed: 3 additions & 12 deletions
@@ -47,14 +47,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
 #endif
 
   std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
-
-  int GetRemoteVarDeviceId(const std::string &var_name) const override {
-    auto got = remote_vars_devices_.find(var_name);
-    if (got != remote_vars_devices_.end()) {
-      return got->second;
-    }
-    return -1;
-  }
+  int GetVarDeviceID(const std::string &varname) const;
 
  private:
   void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op,
@@ -105,9 +98,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
                                 const std::string &og,
                                 std::unordered_set<std::string> *og_has_been_broadcast) const;
 
-  int GetOpDeviceID(
-      const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
-      const OpDesc &op) const;
+  int GetOpDeviceID(const OpDesc &op) const;
 
   void InsertAllReduceOp(SSAGraph *result, const std::string &og) const;
 
@@ -120,7 +111,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
 
  private:
   BuildStrategy strategy_;
-  mutable std::unordered_map<std::string, int> remote_vars_devices_;
+  mutable std::unordered_map<std::string, int> var_name_on_devices_;
 
   void SetCommunicationContext(OpHandleBase *op_handle,
                                const platform::Place &p) const;

paddle/fluid/framework/details/ssa_graph_builder.h

Lines changed: 1 addition & 3 deletions
@@ -30,9 +30,7 @@ class SSAGraphBuilder {
   SSAGraphBuilder() {}
   virtual ~SSAGraphBuilder() {}
   virtual std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const = 0;
-  virtual int GetRemoteVarDeviceId(const std::string &var_name) const {
-    return -1;
-  }
+  virtual int GetVarDeviceID(const std::string &var_name) const { return -1; }
 
   DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);

paddle/fluid/framework/parallel_executor.cc

Lines changed: 2 additions & 3 deletions
@@ -161,9 +161,8 @@ void ParallelExecutor::BCastParamsToGPUs(
       }
       auto &nccl_ctx = member_->nccl_ctxs_->at(place);
 
-      if (builder_.get() != nullptr &&
-          builder_->GetRemoteVarDeviceId(var) != -1) {
-        int place_id = builder_->GetRemoteVarDeviceId(var);
+      if (builder_.get() != nullptr && builder_->GetVarDeviceID(var) != -1) {
+        int place_id = builder_->GetVarDeviceID(var);
         platform::dynload::ncclBcast(buffer, numel, data_type, place_id,
                                      nccl_ctx.comm_, nccl_ctx.stream());
       } else {
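With GetVarDeviceID wired through to the executor, the broadcast root is no longer fixed: when the builder has pinned a parameter to a device, that device id is passed to ncclBcast as the root, so the value is broadcast from the place that already holds it rather than always from device 0.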
