Skip to content

Commit f3e4e42

Browse files
authored
Merge pull request #10130 from reyoung/feature/skip_loss
Add customize_loss_grad option to PE
2 parents 3863c6a + 55feba9 commit f3e4e42

File tree

6 files changed

+27
-17
lines changed

6 files changed

+27
-17
lines changed

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
3434
const std::vector<platform::Place> &places,
3535
const std::string &loss_var_name,
3636
const std::unordered_set<std::string> &params,
37-
const std::vector<Scope *> &local_scopes,
37+
const std::vector<Scope *> &local_scopes, bool skip_scale_loss,
3838
platform::NCCLContextMap *nccl_ctxs)
3939
: loss_var_name_(loss_var_name),
4040
places_(places),
@@ -45,14 +45,15 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
4545
const std::vector<platform::Place> &places,
4646
const std::string &loss_var_name,
4747
const std::unordered_set<std::string> &params,
48-
const std::vector<Scope *> &local_scopes)
48+
const std::vector<Scope *> &local_scopes, bool skip_scale_loss)
4949
: loss_var_name_(loss_var_name),
5050
places_(places),
5151
local_scopes_(local_scopes) {
5252
#endif
5353
for (auto &p : params) {
5454
grad_names_.insert(GradVarName(p));
5555
}
56+
skip_scale_loss_ = skip_scale_loss;
5657
}
5758

5859
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
@@ -95,7 +96,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
9596
// always use the first device
9697
CreateSendOp(&result, *op);
9798
} else if (IsScaleLossOp(*op)) {
98-
CreateScaleLossGradOp(&result);
99+
if (!skip_scale_loss_) {
100+
CreateScaleLossGradOp(&result);
101+
}
99102
is_forwarding = false;
100103
} else {
101104
CreateComputationalOps(&result, *op);

paddle/fluid/framework/details/multi_devices_graph_builder.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,14 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
3434
const std::string &loss_var_name,
3535
const std::unordered_set<std::string> &params,
3636
const std::vector<Scope *> &local_scopes,
37+
bool skip_scale_loss,
3738
platform::NCCLContextMap *nccl_ctxs);
3839
#else
3940
MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places,
4041
const std::string &loss_var_name,
4142
const std::unordered_set<std::string> &params,
42-
const std::vector<Scope *> &local_scopes);
43+
const std::vector<Scope *> &local_scopes,
44+
bool skip_scale_loss);
4345
#endif
4446

4547
std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
@@ -57,6 +59,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
5759
#ifdef PADDLE_WITH_CUDA
5860
platform::NCCLContextMap *nccl_ctxs_;
5961
#endif
62+
bool skip_scale_loss_;
6063

6164
bool IsScaleLossOp(const OpDesc &op) const;
6265

paddle/fluid/framework/parallel_executor.cc

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ ParallelExecutor::ParallelExecutor(
5757
const std::unordered_set<std::string> &params,
5858
const std::unordered_set<std::string> &bcast_vars,
5959
const ProgramDesc &main_program, const std::string &loss_var_name,
60-
Scope *scope, const std::vector<Scope *> &local_scopes, bool allow_op_delay)
60+
Scope *scope, const std::vector<Scope *> &local_scopes, bool allow_op_delay,
61+
bool customize_scale_loss)
6162
: member_(new ParallelExecutorPrivate(places)) {
6263
member_->global_scope_ = scope;
6364

@@ -90,12 +91,13 @@ ParallelExecutor::ParallelExecutor(
9091
// Step 2. Convert main_program to SSA form and dependency graph. Also, insert
9192
// ncclOp
9293
#ifdef PADDLE_WITH_CUDA
93-
details::MultiDevSSAGraphBuilder builder(member_->places_, loss_var_name,
94-
params, member_->local_scopes_,
95-
member_->nccl_ctxs_.get());
94+
details::MultiDevSSAGraphBuilder builder(
95+
member_->places_, loss_var_name, params, member_->local_scopes_,
96+
customize_scale_loss, member_->nccl_ctxs_.get());
9697
#else
9798
details::MultiDevSSAGraphBuilder builder(member_->places_, loss_var_name,
98-
params, member_->local_scopes_);
99+
params, member_->local_scopes_,
100+
customize_scale_loss);
99101
#endif
100102
auto graph = builder.Build(main_program);
101103

paddle/fluid/framework/parallel_executor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class ParallelExecutor {
4040
const ProgramDesc& main_program,
4141
const std::string& loss_var_name, Scope* scope,
4242
const std::vector<Scope*>& local_scopes,
43-
bool allow_op_delay);
43+
bool allow_op_delay, bool customize_scale_loss);
4444

4545
~ParallelExecutor();
4646

paddle/fluid/pybind/pybind.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -502,11 +502,11 @@ All parameter, weight, gradient are variables in Paddle.
502502
const std::unordered_set<std::string> &bcast_vars,
503503
const ProgramDesc &main_program, const std::string &loss_var_name,
504504
Scope *scope, std::vector<Scope *> &local_scopes,
505-
bool allow_op_delay) {
506-
new (&self)
507-
ParallelExecutor(num_threads, use_event, places, params,
508-
bcast_vars, main_program, loss_var_name,
509-
scope, local_scopes, allow_op_delay);
505+
bool allow_op_delay, bool customize_loss_grad) {
506+
new (&self) ParallelExecutor(num_threads, use_event, places,
507+
params, bcast_vars, main_program,
508+
loss_var_name, scope, local_scopes,
509+
allow_op_delay, customize_loss_grad);
510510
})
511511
.def("bcast_params", &ParallelExecutor::BCastParamsToGPUs)
512512
// NOTE: even we return a vec<Scope*>* to Python use reference policy.

python/paddle/fluid/parallel_executor.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ def __init__(self,
2929
main_program=None,
3030
num_threads=None,
3131
allow_op_delay=False,
32-
share_vars_from=None):
32+
share_vars_from=None,
33+
customize_loss_grad=False):
3334
"""
3435
ParallelExecutor can run program in parallel.
3536
@@ -122,7 +123,8 @@ def __init__(self,
122123
loss_name if loss_name else '',
123124
scope,
124125
local_scopes,
125-
allow_op_delay)
126+
allow_op_delay,
127+
customize_loss_grad)
126128
self.scope = scope
127129

128130
def run(self, fetch_list, feed=None, feed_dict=None):

0 commit comments

Comments (0)