Skip to content

Commit 9923be5

Browse files
authored
Merge pull request #10546 from chengduoZH/feature/change_pe_strategy
Balance parameter_opt between cards
2 parents 43b6d4f + 54cbf79 commit 9923be5

File tree

7 files changed

+168
-45
lines changed

7 files changed

+168
-45
lines changed

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 64 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,20 +37,26 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
3737
const std::string &loss_var_name,
3838
const std::unordered_set<std::string> &params,
3939
const std::vector<Scope *> &local_scopes,
40-
platform::NCCLContextMap *nccl_ctxs, bool use_default_grad_scale)
40+
platform::NCCLContextMap *nccl_ctxs, bool use_default_grad_scale,
41+
bool balance_parameter_opt_between_cards)
4142
: loss_var_name_(loss_var_name),
4243
places_(places),
4344
local_scopes_(local_scopes),
44-
nccl_ctxs_(nccl_ctxs) {
45+
nccl_ctxs_(nccl_ctxs),
46+
balance_parameter_opt_between_cards_(
47+
balance_parameter_opt_between_cards) {
4548
#else
4649
MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
4750
const std::vector<platform::Place> &places,
4851
const std::string &loss_var_name,
4952
const std::unordered_set<std::string> &params,
50-
const std::vector<Scope *> &local_scopes, bool use_default_grad_scale)
53+
const std::vector<Scope *> &local_scopes, bool use_default_grad_scale,
54+
bool balance_parameter_opt_between_cards)
5155
: loss_var_name_(loss_var_name),
5256
places_(places),
53-
local_scopes_(local_scopes) {
57+
local_scopes_(local_scopes),
58+
balance_parameter_opt_between_cards_(
59+
balance_parameter_opt_between_cards) {
5460
#endif
5561
for (auto &p : params) {
5662
grad_names_.insert(GradVarName(p));
@@ -124,6 +130,12 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
124130
// Find "send" op first for split is in front of send.
125131
OpDesc *send_op = GetSendOpDesc(program);
126132

133+
size_t cur_device_id = 0;
134+
std::vector<std::unordered_set<std::string>> var_name_on_devices;
135+
std::vector<std::unordered_set<std::string>> bcast_var_name_set;
136+
var_name_on_devices.resize(places_.size());
137+
bcast_var_name_set.resize(places_.size());
138+
127139
bool is_forwarding = true;
128140
for (auto *op : program.Block(0).AllOps()) {
129141
if (op->Type() == "send") {
@@ -139,24 +151,47 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
139151
}
140152
is_forwarding = false;
141153
} else {
142-
CreateComputationalOps(&result, *op, places_.size());
154+
int op_dev_id = GetOpDeviceID(var_name_on_devices, *op);
155+
if (op_dev_id == -1) { // var on all device
156+
CreateComputationalOps(&result, *op, places_.size());
157+
} else {
158+
CreateComputationalOp(&result, *op, op_dev_id);
159+
for (auto &var_name : op->OutputArgumentNames()) {
160+
var_name_on_devices[op_dev_id].emplace(var_name);
161+
}
162+
}
143163
if (!is_forwarding && places_.size() > 1) {
144164
// Currently, we assume that once gradient is generated, it can be
145165
// broadcast, and each gradient is only broadcast once.
146166
for (auto &og : op->OutputArgumentNames()) {
147167
if (IsParameterGradientOnce(og, &og_has_been_broadcast)) {
148-
if (IsSparseGradient(var_types, og)) {
149-
CreateReduceOp(&result, og, 0);
150-
CreateBroadcastOp(&result, og, 0);
168+
if (balance_parameter_opt_between_cards_) {
169+
CreateReduceOp(&result, og, cur_device_id);
170+
var_name_on_devices[cur_device_id].emplace(og);
171+
bcast_var_name_set[cur_device_id].emplace(
172+
og.substr(0, og.size() - strlen(kGradVarSuffix)));
173+
cur_device_id = (cur_device_id + 1) % places_.size();
151174
} else {
152-
InsertNCCLAllReduceOp(&result, og);
175+
if (IsSparseGradient(var_types, og)) {
176+
CreateReduceOp(&result, og, 0);
177+
CreateBroadcastOp(&result, og, 0);
178+
} else {
179+
InsertNCCLAllReduceOp(&result, og);
180+
}
153181
}
154182
}
155183
}
156184
}
157185
}
158186
}
159187

188+
// Insert BCast Ops
189+
for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
190+
auto &to_bcast_set = bcast_var_name_set[dev_id];
191+
for (auto &bcast_name : to_bcast_set) {
192+
CreateBroadcastOp(&result, bcast_name, dev_id);
193+
}
194+
}
160195
/*
161196
Dependency graph has been constructed. However, there are still data
162197
hazards that need to be handled.
@@ -265,6 +300,26 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
265300
return is_pg_once;
266301
}
267302

303+
int MultiDevSSAGraphBuilder::GetOpDeviceID(
304+
const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
305+
const OpDesc &op) const {
306+
if (!balance_parameter_opt_between_cards_) {
307+
return -1;
308+
}
309+
310+
int var_dev_id = -1;
311+
for (auto &var_name : op.InputArgumentNames()) {
312+
if (var_dev_id != -1) break;
313+
for (size_t i = 0; i < var_name_on_devices.size(); ++i) {
314+
if (var_name_on_devices[i].count(var_name)) {
315+
var_dev_id = static_cast<int>(i);
316+
break;
317+
}
318+
}
319+
}
320+
return var_dev_id;
321+
}
322+
268323
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
269324
for (size_t i = 0; i < places_.size(); ++i) {
270325
// Insert ScaleCost OpHandle

paddle/fluid/framework/details/multi_devices_graph_builder.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,15 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
3636
const std::unordered_set<std::string> &params,
3737
const std::vector<Scope *> &local_scopes,
3838
platform::NCCLContextMap *nccl_ctxs,
39-
bool use_default_grad_scale);
39+
bool use_default_grad_scale,
40+
bool balance_parameter_opt_between_cards);
4041
#else
4142
MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places,
4243
const std::string &loss_var_name,
4344
const std::unordered_set<std::string> &params,
4445
const std::vector<Scope *> &local_scopes,
45-
bool use_default_grad_scale);
46+
bool use_default_grad_scale,
47+
bool balance_parameter_opt_between_cards);
4648
#endif
4749

4850
std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
@@ -60,6 +62,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
6062
#ifdef PADDLE_WITH_CUDA
6163
platform::NCCLContextMap *nccl_ctxs_;
6264
#endif
65+
bool balance_parameter_opt_between_cards_;
6366
bool use_default_grad_scale_;
6467

6568
bool IsScaleLossOp(const OpDesc &op) const;
@@ -84,6 +87,10 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
8487
const std::string &og,
8588
std::unordered_set<std::string> *og_has_been_broadcast) const;
8689

90+
int GetOpDeviceID(
91+
const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
92+
const OpDesc &op) const;
93+
8794
void InsertNCCLAllReduceOp(SSAGraph *result, const std::string &og) const;
8895

8996
void CreateBroadcastOp(SSAGraph *result, const std::string &p_name,

paddle/fluid/framework/parallel_executor.cc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ ParallelExecutor::ParallelExecutor(
5858
const std::unordered_set<std::string> &bcast_vars,
5959
const ProgramDesc &main_program, const std::string &loss_var_name,
6060
Scope *scope, const std::vector<Scope *> &local_scopes, bool allow_op_delay,
61-
bool use_default_grad_scale)
61+
bool use_default_grad_scale, bool balance_parameter_opt_between_cards)
6262
: member_(new ParallelExecutorPrivate(places)) {
6363
member_->global_scope_ = scope;
6464

@@ -93,11 +93,12 @@ ParallelExecutor::ParallelExecutor(
9393
#ifdef PADDLE_WITH_CUDA
9494
details::MultiDevSSAGraphBuilder builder(
9595
member_->places_, loss_var_name, params, member_->local_scopes_,
96-
member_->nccl_ctxs_.get(), use_default_grad_scale);
96+
member_->nccl_ctxs_.get(), use_default_grad_scale,
97+
balance_parameter_opt_between_cards);
9798
#else
98-
details::MultiDevSSAGraphBuilder builder(member_->places_, loss_var_name,
99-
params, member_->local_scopes_,
100-
use_default_grad_scale);
99+
details::MultiDevSSAGraphBuilder builder(
100+
member_->places_, loss_var_name, params, member_->local_scopes_,
101+
use_default_grad_scale, balance_parameter_opt_between_cards);
101102
#endif
102103
auto graph = builder.Build(main_program);
103104

paddle/fluid/framework/parallel_executor.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ class ParallelExecutor {
4040
const ProgramDesc& main_program,
4141
const std::string& loss_var_name, Scope* scope,
4242
const std::vector<Scope*>& local_scopes,
43-
bool allow_op_delay, bool use_default_grad_scale);
43+
bool allow_op_delay, bool use_default_grad_scale,
44+
bool balance_parameter_opt_between_cards);
4445

4546
~ParallelExecutor();
4647

paddle/fluid/pybind/pybind.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -502,11 +502,13 @@ All parameter, weight, gradient are variables in Paddle.
502502
const std::unordered_set<std::string> &bcast_vars,
503503
const ProgramDesc &main_program, const std::string &loss_var_name,
504504
Scope *scope, std::vector<Scope *> &local_scopes,
505-
bool allow_op_delay, bool use_default_grad_scale) {
505+
bool allow_op_delay, bool use_default_grad_scale,
506+
bool balance_parameter_opt_between_cards) {
506507
new (&self) ParallelExecutor(
507508
num_threads, use_event, places, params, bcast_vars,
508509
main_program, loss_var_name, scope, local_scopes,
509-
allow_op_delay, use_default_grad_scale);
510+
allow_op_delay, use_default_grad_scale,
511+
balance_parameter_opt_between_cards);
510512
})
511513
.def("bcast_params", &ParallelExecutor::BCastParamsToGPUs)
512514
// NOTE: even we return a vec<Scope*>* to Python use reference policy.

python/paddle/fluid/parallel_executor.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ def __init__(self,
3030
num_threads=None,
3131
allow_op_delay=False,
3232
share_vars_from=None,
33-
use_default_grad_scale=True):
33+
use_default_grad_scale=True,
34+
balance_parameter_opt_between_cards=False):
3435
"""
3536
ParallelExecutor can run program in parallel.
3637
@@ -51,6 +52,9 @@ def __init__(self,
5152
gradients of each device and scaled gradients would be
5253
aggregated. Otherwise, a customized scale value should be fed
5354
to the network.
55+
balance_parameter_opt_between_cards(bool, default False): Whether
56+
to update the gradients of different parameters on different cards. Currently, it
57+
is not recommended.
5458
5559
Returns:
5660
A ParallelExecutor object.
@@ -129,7 +133,8 @@ def __init__(self,
129133
scope,
130134
local_scopes,
131135
allow_op_delay,
132-
use_default_grad_scale)
136+
use_default_grad_scale,
137+
balance_parameter_opt_between_cards)
133138

134139
self.scope = scope
135140

0 commit comments

Comments
 (0)