
Commit 2258885

Merge pull request #80 from kozistr/feature/ranger21-optimizer
[Feature] Ranger21 with AdamD
2 parents c6d64ef + 0290558 commit 2258885
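
In short, the merge wires an AdamD-style debias option into Ranger21. Below is a minimal usage sketch of the new flag; the toy model, data, and hyper-parameters are illustrative and not taken from this repository, while the Ranger21 signature (required num_iterations, lr, adamd_debias_term) comes from the diff further down.

    import torch
    from torch import nn

    from pytorch_optimizer import Ranger21

    model = nn.Linear(4, 1)
    loss_fn = nn.MSELoss()

    # num_iterations is a required argument of Ranger21; adamd_debias_term=True
    # enables the AdamD behaviour added in this PR (skip the numerator bias correction).
    optimizer = Ranger21(model.parameters(), num_iterations=100, lr=1e-3, adamd_debias_term=True)

    for _ in range(100):
        optimizer.zero_grad()
        loss = loss_fn(model(torch.randn(8, 4)), torch.randn(8, 1))
        loss.backward()
        optimizer.step()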

File tree

3 files changed: +42 -14 lines changed

pytorch_optimizer/optimizer/ranger21.py

Lines changed: 11 additions & 11 deletions
@@ -11,15 +11,6 @@
 from pytorch_optimizer.optimizer.gc import centralize_gradient
 from pytorch_optimizer.optimizer.utils import normalize_gradient, unit_norm
 
-__AUTHORS__ = [
-    '@lessw2020',
-    '@NestorDemeure',
-    # with contributions from :
-    '@BrianPugh',
-    '@Kayuksel',
-    '@TheZothen',
-]
-
 
 class Ranger21(Optimizer, BaseOptimizer):
     """
@@ -38,7 +29,7 @@ class Ranger21(Optimizer, BaseOptimizer):
         optimizer.step()
     """
 
-    def __init__(
+    def __init__(  # pylint: disable=R0913
         self,
         params: PARAMETERS,
         num_iterations: int,
@@ -58,6 +49,7 @@ def __init__(
         lookahead_blending_alpha: float = 0.5,
         weight_decay: float = 1e-4,
         norm_loss_factor: float = 1e-4,
+        adamd_debias_term: bool = False,
         eps: float = 1e-8,
     ):
         """Ranger21 optimizer
@@ -76,6 +68,7 @@ def __init__(
         :param lookahead_blending_alpha: float. blending alpha
         :param weight_decay: float. weight decay (L2 penalty)
         :param norm_loss_factor: float. norm loss factor
+        :param adamd_debias_term: bool. Only correct the denominator to avoid inflating step sizes early in training
         :param eps: float. term added to the denominator to improve numerical stability
         """
         self.lr = lr
@@ -91,6 +84,7 @@ def __init__(
         self.lookahead_blending_alpha = lookahead_blending_alpha
         self.weight_decay = weight_decay
         self.norm_loss_factor = norm_loss_factor
+        self.adamd_debias_term = adamd_debias_term
         self.eps = eps
 
         self.validate_parameters()
@@ -108,6 +102,7 @@ def __init__(
             betas=betas,
             eps=eps,
             weight_decay=weight_decay,
+            adamd_debias_term=adamd_debias_term,
         )
         super().__init__(params, defaults)
 
@@ -240,6 +235,9 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                variance_ma_sum += (variance_ma / bias_correction2).sum()
 
         # stable weight decay
+        if param_size == 0:
+            raise ValueError('[-] size of parameter is 0')
+
         variance_normalized = math.sqrt(variance_ma_sum / param_size)
         if math.isnan(variance_normalized):
             raise RuntimeError('hit nan for variance_normalized')
@@ -299,7 +297,9 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
             noise_norm: float = math.sqrt((1.0 + beta2) ** 2 + beta2 ** 2)
 
-            step_size: float = lr / bias_correction1
+            step_size: float = lr
+            if not group['adamd_debias_term']:
+                step_size /= bias_correction1
 
             if self.use_softplus:
                 de_nom = F.softplus(de_nom, beta=self.beta_softplus)
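
The behavioural core of the change is the last hunk: with adamd_debias_term enabled, the learning rate is no longer divided by the first-moment bias correction; only the second-moment (denominator) term stays corrected. A small standalone sketch (illustrative numbers, not repository code) of what this does to the effective step size early in training, assuming the usual Adam correction bias_correction1 = 1 - beta1 ** step:

    beta1, lr = 0.9, 1e-3

    for step in (1, 10, 100):
        bias_correction1 = 1.0 - beta1 ** step

        # default path: dividing by a small bias_correction1 inflates early steps
        step_size_adam = lr / bias_correction1

        # adamd_debias_term=True: the step size stays at lr from the first update
        step_size_adamd = lr

        print(step, round(step_size_adam, 6), step_size_adamd)

    # step 1 prints 0.01 vs 0.001: roughly a 10x larger first step without AdamD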

tests/test_optimizer_parameters.py

Lines changed: 25 additions & 1 deletion
@@ -1,7 +1,9 @@
 from typing import List
 
 import pytest
+import torch
 from torch import nn
+from torch.nn import functional as F
 
 from pytorch_optimizer import SAM, AdamP, Lookahead, PCGrad, Ranger21, SafeFP16Optimizer, load_optimizer
 from tests.utils import Example
@@ -205,7 +207,29 @@ def test_safe_fp16_methods():
     assert optimizer.loss_scale == 2.0 ** (15 - 1)
 
 
-def test_ranger21_methods():
+def test_ranger21_warm_methods():
     assert Ranger21.build_warm_up_iterations(1000, 0.999) == 220
     assert Ranger21.build_warm_up_iterations(4500, 0.999) == 2000
     assert Ranger21.build_warm_down_iterations(1000) == 280
+
+
+def test_ranger21_size_of_parameter():
+    model: nn.Module = nn.Linear(1, 1, bias=False)
+    model.requires_grad_(False)
+
+    with pytest.raises(ValueError):
+        Ranger21(model.parameters(), 100).step()
+
+
+def test_ranger21_closure():
+    model: nn.Module = Example()
+    optimizer = Ranger21(model.parameters(), num_iterations=100, betas=(0.9, 1e-9))
+
+    loss_fn = nn.BCEWithLogitsLoss()
+
+    def closure():
+        loss = loss_fn(torch.ones((1, 1)), model(torch.ones((1, 1))))
+        loss.backward()
+        return loss
+
+    optimizer.step(closure)

tests/test_optimizers.py

Lines changed: 6 additions & 2 deletions
@@ -97,6 +97,7 @@
     (RaLamb, {'lr': 1e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 500),
     (RAdam, {'lr': 1e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 200),
     (Ranger, {'lr': 5e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 200),
+    (Ranger21, {'lr': 5e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 200),
 ]
 
 
@@ -247,11 +248,14 @@ def closure():
 def test_adamd_optimizers(optimizer_adamd_config):
     (x_data, y_data), model, loss_fn = build_environment()
 
-    optimizer_class, config, iterations = optimizer_adamd_config
+    optimizer_class, config, num_iterations = optimizer_adamd_config
+    if optimizer_class.__name__ == 'Ranger21':
+        config.update({'num_iterations': num_iterations})
+
     optimizer = optimizer_class(model.parameters(), **config)
 
     init_loss, loss = np.inf, np.inf
-    for _ in range(iterations):
+    for _ in range(num_iterations):
         optimizer.zero_grad()
 
         y_pred = model(x_data)
