
Commit 79afa2e

Merge pull request #126 from kozistr/refactor/ranger21-optimizer
[Refactor] Ranger21 optimizer
2 parents 84a24df + 90ce1fd

2 files changed: +41 −35 lines changed

pytorch_optimizer/optimizer/ranger21.py

Lines changed: 27 additions & 34 deletions
@@ -5,7 +5,7 @@
 from torch.nn import functional as f
 from torch.optim import Optimizer
 
-from pytorch_optimizer.base.exception import NegativeLRError, NoSparseGradientError, ZeroParameterSizeError
+from pytorch_optimizer.base.exception import NoSparseGradientError, ZeroParameterSizeError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
 from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
 from pytorch_optimizer.optimizer.agc import agc
@@ -73,6 +73,7 @@ def __init__( # pylint: disable=R0913
         eps: float = 1e-8,
     ):
         self.lr = lr
+        self.min_lr = warm_down_min_lr
         self.beta0 = beta0
         self.betas = betas
         self.use_softplus = use_softplus
@@ -96,7 +97,6 @@ def __init__( # pylint: disable=R0913
         # learning rate
         self.starting_lr = lr
         self.current_lr = lr
-        self.min_lr = warm_down_min_lr
 
         defaults: DEFAULTS = {
             'lr': lr,
@@ -123,6 +123,7 @@ def __init__( # pylint: disable=R0913
 
     def validate_parameters(self):
         self.validate_learning_rate(self.lr)
+        self.validate_learning_rate(self.min_lr)
         self.validate_betas(self.betas)
         self.validate_beta0(self.beta0)
         self.validate_weight_decay(self.weight_decay)
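A note on ordering: the `min_lr` assignment moved to the top of `__init__` in the earlier hunk, plausibly so the attribute already exists by the time `validate_parameters()` reads it here. A hypothetical mini-repro of that constraint (`TinyOpt` is illustrative, not the library's class):

# Hypothetical mini-repro of the ordering constraint: validation reads
# self.min_lr, so the assignment must precede the validate call.
class TinyOpt:
    def __init__(self, lr: float, warm_down_min_lr: float):
        self.lr = lr
        self.min_lr = warm_down_min_lr  # must be set before validation runs
        self.validate_parameters()

    def validate_parameters(self) -> None:
        for name in ('lr', 'min_lr'):
            if getattr(self, name) < 0.0:
                raise ValueError(f'{name} must be non-negative')

TinyOpt(1e-1, 3e-5)  # passes; TinyOpt(1e-1, -1.0) would raise ValueError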
@@ -169,7 +170,7 @@ def warm_up_dampening(self, lr: float, step: int) -> float:
 
         return new_lr
 
-    def get_warm_down(self, lr: float, iteration: int) -> float:
+    def warm_down(self, lr: float, iteration: int) -> float:
         if iteration < self.start_warm_down:
             return lr
 
@@ -182,6 +183,7 @@ def get_warm_down(self, lr: float, iteration: int) -> float:
 
         new_lr: float = self.starting_lr - self.warm_down_lr_delta * warm_down_pct
         new_lr = max(new_lr, self.min_lr)
+
         self.current_lr = new_lr
 
         return new_lr
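The renamed `warm_down` keeps its linear schedule: the learning rate falls from `starting_lr` toward `min_lr` over the warm-down window, and the `max(new_lr, self.min_lr)` clamp guarantees it never goes negative, which is why the `NegativeLRError` check in `step()` (removed below) and its import (removed above) are no longer needed. A minimal standalone sketch of that schedule, with the Ranger21 attributes flattened into plain arguments:

def linear_warm_down(lr: float, iteration: int, start_warm_down: int,
                     num_warm_down_iterations: int, starting_lr: float,
                     min_lr: float) -> float:
    """Sketch of a linear lr warm-down; argument names mirror Ranger21 attributes."""
    if iteration < start_warm_down:
        return lr  # warm-down has not started yet
    # fraction of the warm-down window already consumed
    warm_down_pct = (iteration - start_warm_down) / num_warm_down_iterations
    new_lr = starting_lr - (starting_lr - min_lr) * warm_down_pct
    return max(new_lr, min_lr)  # clamp rather than let lr go negative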
@@ -198,6 +200,11 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
         # Phase 1 - Accumulate all the variance_ma_sum to use in stable weight decay
         for group in self.param_groups:
+            if 'step' in group:
+                group['step'] += 1
+            else:
+                group['step'] = 1
+
             beta1, beta2 = group['betas']
             for p in group['params']:
                 if p.grad is None:
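The step counter now lives once per param group rather than once per parameter, bumped a single time at the top of `step()`. Every parameter in a group advances in lockstep, so one counter suffices, and Phase 2 below no longer has to fish a step count out of an arbitrary parameter's state. A self-contained illustration of the pattern (param groups are plain dicts):

param_groups = [{'lr': 1e-3, 'betas': (0.9, 0.999)}]

for _ in range(3):  # three optimizer steps
    for group in param_groups:
        # equivalent, more compact form of the membership test above
        group['step'] = group.get('step', 0) + 1

assert param_groups[0]['step'] == 3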
@@ -214,7 +221,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 state = self.state[p]
                 if len(state) == 0:
-                    state['step'] = 0
                     state['grad_ma'] = torch.zeros_like(p)
                     state['variance_ma'] = torch.zeros_like(p)
                     state['lookahead_params'] = torch.empty_like(p)
@@ -227,17 +233,14 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     grad = centralize_gradient(grad, gc_conv_only=False)
                     grad = normalize_gradient(grad)
 
-                state['step'] += 1
-
-                bias_correction2 = 1.0 - beta2 ** state['step']
+                bias_correction2 = 1.0 - beta2 ** group['step']
 
                 # second moment estimation
                 # using positive-negative momentum and bias correction
                 variance_ma = state['variance_ma']
                 variance_ma.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                 variance_ma_sum += (variance_ma / bias_correction2).sum()
 
-        # stable weight decay
         if param_size == 0:
             raise ZeroParameterSizeError()
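The bias correction now reads the shared `group['step']`; the accumulation itself is unchanged. It is the standard Adam-style EMA of squared gradients, debiased by `1 - beta2 ** step` before being summed into the global statistic that stable weight decay divides by. A standalone sketch of this Phase 1 computation:

import torch

beta2, step = 0.999, 10
grad = torch.randn(4)
variance_ma = torch.zeros_like(grad)

# EMA of squared gradients: v <- beta2 * v + (1 - beta2) * grad^2
variance_ma.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

# debias before summing into the global statistic
bias_correction2 = 1.0 - beta2 ** step
variance_ma_sum = (variance_ma / bias_correction2).sum()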

@@ -247,42 +250,32 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
         # Phase 2 - Apply weight decay and step
         for group in self.param_groups:
-            if len(self.state) == 0:
-                continue
-            p = next(iter(self.state.keys()))
-            lr = group["lr"]
-            step = self.state[p]["step"]
-
+            lr = group['lr']
             beta1, beta2 = group['betas']
-            bias_correction1 = 1.0 - beta1 ** step  # fmt: skip
-            bias_correction2_sq = math.sqrt(1.0 - beta2 ** step)  # fmt: skip
-
-            noise_norm: float = math.sqrt((1.0 + beta2) ** 2 + beta2 ** 2)  # fmt: skip
-
-            # warm up
-            lr = self.warm_up_dampening(lr, step)
-
-            # warm down
-            lr = self.get_warm_down(lr, step)
-            if lr < 0.0:
-                raise NegativeLRError(lr)
+            bias_correction1 = 1.0 - beta1 ** group['step']  # fmt: skip
+            bias_correction2_sq = math.sqrt(1.0 - beta2 ** group['step'])  # fmt: skip
 
-            # stable decay
-            decay = group['weight_decay']
-            if decay:
-                p.mul_(1.0 - decay * lr / variance_normalized)
+            noise_norm: float = math.sqrt((1.0 + beta2) ** 2 + beta2 ** 2)  # fmt: skip
 
-            # norm loss
-            u_norm = unit_norm(p)
-            correction = 2.0 * self.norm_loss_factor * (1.0 - torch.div(1, u_norm + self.eps))
-            p.mul_(1.0 - lr * correction)
+            # warm up & down
+            lr = self.warm_up_dampening(lr, group['step'])
+            lr = self.warm_down(lr, group['step'])
 
             for p in group['params']:
                 if p.grad is None:
                     continue
 
+                # stable weight decay
+                if group['weight_decay']:
+                    p.mul_(1.0 - group['weight_decay'] * lr / variance_normalized)
+
+                # norm loss
+                correction = 2.0 * self.norm_loss_factor * (1.0 - torch.div(1, unit_norm(p) + self.eps))
+                p.mul_(1.0 - lr * correction)
+
                 state = self.state[p]
-                if state['step'] % 2 == 1:
+                if group['step'] % 2 == 1:
                     grad_ma, neg_grad_ma = state['grad_ma'], state['neg_grad_ma']
                 else:
                     grad_ma, neg_grad_ma = state['neg_grad_ma'], state['grad_ma']
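This hunk carries the substantive fix: the old Phase 2 grabbed a single parameter via `next(iter(self.state.keys()))` and applied stable weight decay and norm loss to it alone, outside the per-parameter loop; now both run inside the loop for every parameter that has a gradient. A minimal sketch of the two in-loop updates on a dummy tensor (`p.norm(2)` stands in for the library's `unit_norm`, and the constants are illustrative):

import torch

p = torch.randn(8, 8)
lr, weight_decay, norm_loss_factor, eps = 1e-1, 1e-4, 1e-4, 1e-8
variance_normalized = 0.5  # stands in for the global Phase 1 statistic

# stable weight decay: decay is scaled by lr and normalized by the
# global second-moment statistic rather than applied raw
if weight_decay:
    p.mul_(1.0 - weight_decay * lr / variance_normalized)

# norm loss: softly pulls the parameter's norm toward 1
u_norm = p.norm(2)  # stand-in for unit_norm(p)
correction = 2.0 * norm_loss_factor * (1.0 - torch.div(1, u_norm + eps))
p.mul_(1.0 - lr * correction)

The `group['step'] % 2` test at the end drives positive-negative momentum: the roles of the two momentum buffers swap on alternate steps.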

tests/test_optimizer_parameters.py

Lines changed: 14 additions & 1 deletion
@@ -244,12 +244,25 @@ def test_safe_fp16_methods():
     assert optimizer.loss_scale == 2.0 ** (15 - 1)
 
 
-def test_ranger21_warm_methods():
+def test_ranger21_warm_iterations():
     assert Ranger21.build_warm_up_iterations(1000, 0.999) == 220
     assert Ranger21.build_warm_up_iterations(4500, 0.999) == 2000
     assert Ranger21.build_warm_down_iterations(1000) == 280
 
 
+def test_ranger21_warm_up_and_down():
+    param = simple_parameter(require_grad=False)
+
+    lr: float = 1e-1
+    opt = Ranger21([param], num_iterations=500, lr=lr, warm_down_min_lr=3e-5)
+
+    assert opt.warm_up_dampening(lr, 100) == 0.09090909090909091
+    assert opt.warm_up_dampening(lr, 200) == 0.1
+    assert opt.warm_up_dampening(lr, 300) == 0.1
+    assert opt.warm_down(lr, 300) == 0.1
+    assert opt.warm_down(lr, 400) == 0.07093070921985817
+
+
 @pytest.mark.parametrize('optimizer', ['ranger21', 'adai'])
 def test_size_of_parameter(optimizer):
     param = simple_parameter(require_grad=False)
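The expected values follow from the schedule arithmetic, assuming the 0.22 warm-up and 0.28 warm-down fractions implied by `test_ranger21_warm_iterations` above (the untuned-warm-up rule 2 / (1 - beta2) = 2000 exceeds the 500-step horizon, so the percentage fallback applies). A short check of that arithmetic:

import math

num_iterations, lr = 500, 1e-1
warm_up_steps = int(0.22 * num_iterations)    # 110
warm_down_steps = int(0.28 * num_iterations)  # 140, i.e. steps 360..500

# step 100 is still warming up: lr is scaled by step / warm_up_steps
assert math.isclose(lr * 100 / warm_up_steps, 0.09090909090909091)

# steps 200 and 300 are past warm-up but before step 360, so both
# warm_up_dampening and warm_down leave lr at 0.1 there; step 400 sits
# about 29% of the way down the linear ramp from 0.1 to 3e-5, giving
# roughly 0.0709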
