Commit 0251320

fix find_unused_parameters default value (#32829)
fix error log for reducer; fix doc; fix bug of utest; fix spawn; fix coverage
1 parent 09b18a4 commit 0251320
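
For context (not part of this commit): the find_unused_parameters flag only matters for models in which some parameters may receive no gradient in a given step, for example when the forward pass branches. A minimal illustrative sketch of such a model follows; the class and its members are hypothetical, only paddle.nn.Layer and paddle.nn.Linear are real APIs.

import paddle.nn as nn

class BranchySketch(nn.Layer):
    # Hypothetical example: depending on `use_a`, one of the two Linear
    # layers is skipped, so its parameters get no gradient that step.
    # This is the situation find_unused_parameters=True is meant for.
    def __init__(self):
        super(BranchySketch, self).__init__()
        self.a = nn.Linear(8, 8)
        self.b = nn.Linear(8, 8)

    def forward(self, x, use_a=True):
        return self.a(x) if use_a else self.b(x)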

File tree

10 files changed: +95 −65 lines


paddle/fluid/framework/distributed_strategy.proto

Lines changed: 1 addition & 1 deletion

@@ -172,7 +172,7 @@ message DistributedStrategy {
   optional bool fp16_allreduce = 25 [ default = false ];
   optional bool sharding = 26 [ default = false ];
   optional float last_comm_group_size_MB = 27 [ default = 1 ];
-  optional bool find_unused_parameters = 28 [ default = true ];
+  optional bool find_unused_parameters = 28 [ default = false ];
   optional bool tensor_parallel = 29 [ default = false ];

   optional RecomputeConfig recompute_configs = 101;

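The proto default above backs the fleet DistributedStrategy option. A minimal sketch (not part of this commit) of opting back into per-step unused-parameter detection through the strategy; it assumes a collective fleet setup.

import paddle.distributed.fleet as fleet

# find_unused_parameters now defaults to False (see the proto change above);
# set it explicitly to True to restore the previous behavior.
strategy = fleet.DistributedStrategy()
strategy.find_unused_parameters = True
fleet.init(is_collective=True, strategy=strategy)
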
paddle/fluid/imperative/reducer.cc

Lines changed: 64 additions & 46 deletions

@@ -297,7 +297,7 @@ Reducer::Reducer(const std::vector<std::shared_ptr<imperative::VarBase>> &vars,
       is_sparse_gradient_(is_sparse_gradient),
       parallel_ctx_(parallel_ctx),
       group_size_limits_(group_size_limits),
-      find_unused_vars_(find_unused_vars) {
+      find_unused_vars_each_step_(find_unused_vars) {
   VLOG(3) << "Start construct the Reducer ...";
   nrings_ = parallel_ctx->GetNRings();
   nranks_ = parallel_ctx->GetNRanks();

@@ -457,42 +457,8 @@ void Reducer::PrepareDeps(const std::unordered_set<GradOpNode *> &init_nodes) {
   }
 }

-// After each batch is calculated, the counter of each group(group.pending_)
-// and allreudce sequence counter(next_group_) will be cleaned up again.
-void Reducer::PrepareForBackward(
+void Reducer::TraverseBackwardGraph(
     const std::vector<std::shared_ptr<imperative::VarBase>> &outputs) {
-  VLOG(3) << "after forward, then reset count for backward.";
-  next_group_ = 0;
-  std::for_each(groups_.begin(), groups_.end(), [](Group &group) {
-    group.pending_ = group.variable_indices_.size();
-    group.sparse_contents_ = nullptr;
-  });
-
-  // reinitialize vars_marked_ready_ for next iteration
-  vars_marked_ready_.clear();
-  vars_marked_ready_.resize(vars_.size(), false);
-
-  PADDLE_ENFORCE_EQ(
-      groups_need_finalize_, false,
-      platform::errors::PreconditionNotMet(
-          "A serious error has occurred here. There may be several reasons: "
-          "1) Please note that all forward outputs derived from the module "
-          "parameters must participate in the calculation of losses and "
-          "subsequent gradient calculations. If not, the wrapper will hang, "
-          "waiting for autograd to generate gradients for these parameters. "
-          "you can use detach or stop_gradient to make the unused parameters "
-          "detached from the autograd graph. "
-          "2) Used multiple forwards and one backward. You may be able to wrap "
-          "multiple forwards in a model."));
-
-  // The first var to trigger the unused parameter
-  has_marked_unused_vars_ = false;
-  unused_vars_.clear();
-
-  if (!find_unused_vars_) {
-    return;
-  }
-
   node_deps_.clear();
   std::queue<std::shared_ptr<GradOpNode>> q;
   std::unordered_set<VariableWrapper *> var_visited;

@@ -554,8 +520,50 @@ void Reducer::PrepareForBackward(
               << "] is not used";
     }
   }
+}

-  if (unused_vars_.empty()) {
+// After each batch is calculated, the counter of each group(group.pending_)
+// and allreudce sequence counter(next_group_) will be cleaned up again.
+void Reducer::PrepareForBackward(
+    const std::vector<std::shared_ptr<imperative::VarBase>> &outputs) {
+  VLOG(3) << "after forward, then reset count for backward.";
+  next_group_ = 0;
+  std::for_each(groups_.begin(), groups_.end(), [](Group &group) {
+    group.pending_ = group.variable_indices_.size();
+    group.sparse_contents_ = nullptr;
+  });
+
+  // reinitialize vars_marked_ready_ for next iteration
+  vars_marked_ready_.clear();
+  vars_marked_ready_.resize(vars_.size(), false);
+
+  PADDLE_ENFORCE_EQ(
+      groups_need_finalize_, false,
+      platform::errors::PreconditionNotMet(
+          "A serious error has occurred here. Please "
+          "set find_unused_parameters=True to traverse backward graph "
+          "in each step to prepare reduce in advance. If you have "
+          "set, There may be several reasons for this error: "
+          "1) Please note that all forward outputs derived from the module "
+          "parameters must participate in the calculation of losses and "
+          "subsequent gradient calculations. If not, the wrapper will hang, "
+          "waiting for autograd to generate gradients for these parameters. "
+          "you can use detach or stop_gradient to make the unused parameters "
+          "detached from the autograd graph. "
+          "2) Used multiple forwards and one backward. You may be able to wrap "
+          "multiple forwards in a model."));
+
+  // The first var to trigger the unused parameter
+  has_marked_unused_vars_ = false;
+
+  if (find_unused_vars_once_ || find_unused_vars_each_step_) {
+    unused_vars_.clear();
+    TraverseBackwardGraph(outputs);
+    // only check once in first step
+    find_unused_vars_once_ = false;
+  }
+
+  if (find_unused_vars_each_step_ && unused_vars_.empty()) {
     LOG_FIRST_N(WARNING, 1)
         << "All parameters are involved in the backward pass. "
            "It is recommended to set find_unused_parameters to False "

@@ -564,7 +572,9 @@ void Reducer::PrepareForBackward(
           "will occur. Please make it clear that in the subsequent "
           "training, there will be no parameters that are not used "
           "in the backward pass, and then set find_unused_parameters";
-  } else if (unused_vars_.size() == vars_.size()) {
+  }
+
+  if (unused_vars_.size() == vars_.size()) {
     LOG_FIRST_N(WARNING, 1)
         << "There is no parameter in the device involved "
           "in the backward calculation. If there are "

@@ -595,13 +605,13 @@ void Reducer::AddDistHook(size_t var_index) {

   local_used_vars_[var_index] = 1;

-  // rebuild group when find_unused_vars_ is false
+  // rebuild group when find_unused_vars_each_step_ is false
   if (NeedRebuildGroup()) {
     rebuild_vars_.push_back(vars_[var_index]);
     rebuild_var_indices_.push_back(var_index);
   }

-  if (!has_marked_unused_vars_ && find_unused_vars_) {
+  if (!has_marked_unused_vars_) {
     has_marked_unused_vars_ = true;
     for (const auto &unused_index : unused_vars_) {
       MarkVarReady(unused_index, false);

@@ -622,7 +632,9 @@ void Reducer::MarkVarReady(const size_t var_index, const bool is_used_var) {
   if (vars_marked_ready_[var_index]) {
     auto error_info = string::Sprintf(
         "Error happened, when parameter[%d][%s] has been ready before. "
-        "There may be several reasons for this error: "
+        "Please set find_unused_parameters=True to traverse backward graph "
+        "in each step to prepare reduce in advance. If you have set, "
+        "there may be several reasons for this error: "
         "1) In multiple reentrant backward phase, some parameters are reused."
         "2) Using model parameters outside of forward function. Please "
         "make sure that model parameters are not shared in concurrent "

@@ -690,10 +702,16 @@ void Reducer::MarkVarReady(const size_t var_index, const bool is_used_var) {
     }
   } else {
     // process sparse group
-    PADDLE_ENFORCE_EQ(HasGrad(var_index), true,
-                      platform::errors::PreconditionNotMet(
-                          "The sparse parameter[%d][%s] must have a gradient",
-                          var_index, vars_[var_index]->Name()));
+    PADDLE_ENFORCE_EQ(
+        HasGrad(var_index), true,
+        platform::errors::PreconditionNotMet(
+            "The sparse parameter[%d][%s] should have gradient. "
+            "Currently, DataParallel does not support sparse "
+            "parameters without generating gradients during training. "
+            "For example, if is_sparese=True is used in Embedding, "
+            "the current step of this parameter cannot generate gradient "
+            "because of stop_gradient/detatch, where error will occur.",
+            var_index, vars_[var_index]->Name()));
     auto var_base = vars_[var_index]->GradVarBase();
     // need to check tensor type
     PADDLE_ENFORCE_EQ(

@@ -943,7 +961,7 @@ void Reducer::FinalizeBackward() {
     InitializeGroups(group_indices_);
   }

-  if (find_unused_vars_) {
+  if (find_unused_vars_each_step_) {
 // TODO(liuyuhui) support xpu about Tensorcopy/TensorFromVector/TensorToVector
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
     ProcessUnusedDenseVars();

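To summarize the behavior change in reducer.cc: the backward graph is now always traversed once in the first step (find_unused_vars_once_), and on every step only when the user enables find_unused_parameters (find_unused_vars_each_step_). A Python sketch of that policy follows, with illustrative names only (not Paddle API).

class UnusedVarPolicySketch(object):
    """Illustrative mirror of the Reducer flags introduced above."""

    def __init__(self, find_unused_parameters=False):
        self.find_unused_vars_each_step = find_unused_parameters
        self.find_unused_vars_once = True  # the first step always checks
        self.unused_vars = []

    def prepare_for_backward(self, outputs):
        if self.find_unused_vars_once or self.find_unused_vars_each_step:
            self.unused_vars = self._traverse_backward_graph(outputs)
            # only check once in the first step
            self.find_unused_vars_once = False

    def _traverse_backward_graph(self, outputs):
        # stand-in for Reducer::TraverseBackwardGraph
        return []
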
paddle/fluid/imperative/reducer.h

Lines changed: 6 additions & 2 deletions

@@ -162,13 +162,16 @@ class Reducer {
   std::vector<std::vector<size_t>> RebuildGruops();

   inline bool NeedRebuildGroup() {
-    return !has_rebuilt_group_ && !find_unused_vars_;
+    return !has_rebuilt_group_ && !find_unused_vars_each_step_;
   }

   void ProcessUnusedDenseVars();

   bool HasGrad(size_t var_index);

+  void TraverseBackwardGraph(
+      const std::vector<std::shared_ptr<imperative::VarBase>>& outputs);
+
  private:
   std::vector<std::shared_ptr<imperative::VarBase>> vars_;
   std::vector<std::vector<size_t>> group_indices_;

@@ -195,7 +198,8 @@ class Reducer {
   std::unordered_map<VariableWrapper*, size_t> var_index_map_;
   std::vector<size_t> unused_vars_;
   bool has_marked_unused_vars_{false};
-  bool find_unused_vars_{false};
+  bool find_unused_vars_each_step_{false};
+  bool find_unused_vars_once_{true};
   bool groups_need_finalize_{false};
 #ifdef PADDLE_WITH_XPU_BKCL
   // comm_pool_ is used for scheduling allreduce in multi Kunlun cards training.

python/paddle/distributed/fleet/base/distributed_strategy.py

Lines changed: 1 addition & 1 deletion

@@ -626,7 +626,7 @@ def find_unused_parameters(self):
         Indicating whether we are using find_unused_parameters to
         find unused parameters in DataParallel.

-        Default value: True
+        Default value: False

         Examples:

python/paddle/fluid/dygraph/parallel.py

Lines changed: 6 additions & 9 deletions

@@ -417,14 +417,15 @@ class DataParallel(layers.Layer):
             Note that setting the find_unused_parameters to True
             will affect computing performance. Therefore, if all parameters
             are sure to participate in the loss calculation and the
-            autograd graph construction, please set it False. Default: True.
+            autograd graph construction, please set it False. Default: False.

     Returns:
         Layer: The data paralleled module.

     Examples:
         .. code-block:: python
-
+
+            # required: distributed
             import paddle
             import paddle.nn as nn
             import paddle.optimizer as opt

@@ -474,7 +475,7 @@ def __init__(self,
                  strategy=None,
                  comm_buffer_size=25,
                  last_comm_buffer_size=1,
-                 find_unused_parameters=True):
+                 find_unused_parameters=False):
         super(DataParallel,
               self).__init__(layers.full_name() + "_data_parallel")

@@ -576,12 +577,8 @@ def _find_varbase(self, obj):
     def forward(self, *inputs, **kwargs):
         outputs = self._layers(*inputs, **kwargs)
         if self._strategy.nranks > 1 and framework._dygraph_tracer()._has_grad:
-            if self.find_unused_parameters:
-                self._reducer.prepare_for_backward(
-                    list(self._find_varbase(outputs)))
-            else:
-                self._reducer.prepare_for_backward(list(self._find_varbase([])))
-
+            self._reducer.prepare_for_backward(
+                list(self._find_varbase(outputs)))
         return outputs

     @deprecated(

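Because prepare_for_backward now always receives the forward outputs, the Python-level flag alone controls the per-step traversal. A minimal usage sketch (assumes the script is launched with paddle.distributed.launch or spawn; the model is illustrative):

import paddle
import paddle.nn as nn
import paddle.distributed as dist

dist.init_parallel_env()

layer = nn.Linear(10, 1)
# The default is now find_unused_parameters=False; pass True only if some
# parameters may not receive gradients in a given step.
dp_layer = paddle.DataParallel(layer, find_unused_parameters=True)

x = paddle.randn([4, 10], 'float32')
loss = dp_layer(x).mean()
loss.backward()
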
python/paddle/fluid/tests/unittests/parallel_dygraph_gradient_check.py

Lines changed: 2 additions & 2 deletions

@@ -74,8 +74,8 @@ def test_multiple_gpus(self):
         state_dict = model_a.state_dict()
         model_b.set_state_dict(state_dict)

-        model_a = paddle.DataParallel(model_a)
-        model_b = paddle.DataParallel(model_b)
+        model_a = paddle.DataParallel(model_a, find_unused_parameters=True)
+        model_b = paddle.DataParallel(model_b, find_unused_parameters=True)

         ones_input = paddle.ones(shape=(batch, in_dim))
         ones_input.stop_gradient = True

python/paddle/fluid/tests/unittests/spawn_runner_base.py

Lines changed: 1 addition & 0 deletions

@@ -27,6 +27,7 @@
 class SpawnAssistTestArgs(object):
     update_method = "local"
     trainer_id = 0
+    find_unused_parameters = False


 class TestDistSpawnRunner(unittest.TestCase):

python/paddle/fluid/tests/unittests/test_dist_base.py

Lines changed: 7 additions & 4 deletions

@@ -548,7 +548,10 @@ def run_trainer_with_spawn(self, args):
         # 4. train model
         model, train_reader, opt = self.get_model()
         if args.update_method == "nccl2":
-            model = paddle.DataParallel(model)
+            if args.find_unused_parameters:
+                model = paddle.DataParallel(model, find_unused_parameters=True)
+            else:
+                model = paddle.DataParallel(model, find_unused_parameters=False)

         out_losses = []
         for step_id, data in enumerate(train_reader()):

@@ -581,8 +584,8 @@ def run_use_fleet_api_trainer(self, args):

         # set strategy
         strategy = fleet.DistributedStrategy()
-        if not args.find_unused_parameters:
-            strategy.find_unused_parameters = False
+        if args.find_unused_parameters:
+            strategy.find_unused_parameters = True

         # 3. init parallel env
         if args.update_method == "nccl2" or "bkcl":

@@ -737,7 +740,7 @@ def setUp(self):
         self._save_model = False
         self._fuse_all_reduce = None
         self._accumulate_gradient = False
-        self._find_unused_parameters = True
+        self._find_unused_parameters = False
         self._setup_config()

         global DIST_UT_PORT

python/paddle/fluid/tests/unittests/test_parallel_dygraph_control_flow.py

Lines changed: 6 additions & 0 deletions

@@ -30,6 +30,7 @@ def _setup_config(self):
         self._sync_mode = False
         self._nccl2_mode = True
         self._dygraph = True
+        self._find_unused_parameters = True

     def test_net(self):
         if fluid.core.is_compiled_with_cuda():

@@ -46,6 +47,7 @@ def _setup_config(self):
         self._nccl2_mode = True
         self._dygraph = True
         self._use_fleet_api = True
+        self._find_unused_parameters = True


 class TestFleetDygraphControlFlowSameAccGrad(TestDygraphControlFlowSame):

@@ -54,13 +56,15 @@ def _setup_config(self):
         self._nccl2_mode = True
         self._dygraph = True
         self._accumulate_gradient = True
+        self._find_unused_parameters = True


 class TestDygraphControlFlowDiff(TestDistBase):
     def _setup_config(self):
         self._sync_mode = False
         self._nccl2_mode = True
         self._dygraph = True
+        self._find_unused_parameters = True

     def test_net(self):
         if fluid.core.is_compiled_with_cuda():

@@ -77,6 +81,7 @@ def _setup_config(self):
         self._nccl2_mode = True
         self._dygraph = True
         self._use_fleet_api = True
+        self._find_unused_parameters = True


 class TestFleetDygraphControlFlowDiffAccGrad(TestDygraphControlFlowDiff):

@@ -85,6 +90,7 @@ def _setup_config(self):
         self._nccl2_mode = True
         self._dygraph = True
         self._accumulate_gradient = True
+        self._find_unused_parameters = True


 if __name__ == "__main__":

python/paddle/fluid/tests/unittests/test_parallel_dygraph_mnist.py

Lines changed: 1 addition & 0 deletions

@@ -31,6 +31,7 @@ def _setup_config(self):
         self._sync_mode = False
         self._nccl2_mode = True
         self._dygraph = True
+        self._find_unused_parameters = True

     def test_mnist(self):
         if fluid.core.is_compiled_with_cuda():
