Skip to content

Commit 3f1169f

Browse files
wangxicodinggongweibao
authored and committed
Fix dgc clip & rampup step, test=release/1.6 (#21519)
1 parent 0e63746 commit 3f1169f

File tree

3 files changed

+8
-5
lines changed

3 files changed

+8
-5
lines changed

paddle/fluid/operators/dgc_op.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ inline float get_period_sparcity(const std::vector<float>& sparsity,
2828

2929
size_t idx = static_cast<int>(cur_step * sparsity.size() / rampup_steps);
3030
if (idx >= sparsity.size()) {
31-
return 0.999;
31+
idx = sparsity.size() - 1;
3232
}
3333

3434
PADDLE_ENFORCE_LT(idx, sparsity.size());
@@ -102,8 +102,9 @@ class DGCOpKernel : public framework::OpKernel<T> {
102102
}
103103

104104
float ratio =
105-
1 - get_period_sparcity(sparsity, static_cast<float>(*current_step),
106-
rampup_step);
105+
1 - get_period_sparcity(
106+
sparsity, static_cast<float>(*current_step - rampup_begin_step),
107+
rampup_step);
107108
PADDLE_ENFORCE_GE(ratio, 0.0);
108109
PADDLE_ENFORCE_LT(ratio, 1.0);
109110
int k = static_cast<int>(g->numel() * ratio);

python/paddle/fluid/optimizer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -947,6 +947,7 @@ def __init__(self,
947947
self._momentum = momentum
948948
self._use_nesterov = bool(use_nesterov)
949949

950+
assert rampup_begin_step >= 0, "rampup_begin_step must >= 0"
950951
self._rampup_begin_step = rampup_begin_step
951952
self._rampup_step = rampup_step
952953
self._sparsity = sparsity
@@ -963,8 +964,7 @@ def __init__(self,
963964

964965
self._local_grad_clip_norm = local_grad_clip_norm
965966
self._num_trainers = num_trainers
966-
self._clip_norm = local_grad_clip_norm / (num_trainers *
967-
num_trainers)
967+
self._clip_norm = local_grad_clip_norm * (num_trainers**-0.5)
968968

969969
self._get_dgc_regularization_param()
970970

python/paddle/fluid/tests/unittests/test_dgc_optimizer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ def check_dgc_momentum_optimizer(self,
6767
learning_rate=learning_rate,
6868
momentum=0.2,
6969
rampup_begin_step=0,
70+
local_grad_clip_norm=1.0,
71+
num_trainers=2,
7072
regularization=regularization)
7173
mean_out = block.create_var(
7274
dtype="float32", shape=[1], lod_level=0, name="mean.out")

0 commit comments

Comments (0)