Commit 56c05b9

Merge pull request #46 from kozistr/feature/ralamb-optimizer
[Feature] RaLamb optimizer
2 parents: 4bf0020 + af69606

File tree: 11 files changed, 308 additions, 12 deletions

pytorch_optimizer/__init__.py (1 addition, 0 deletions)

@@ -15,6 +15,7 @@
 from pytorch_optimizer.optimizers import load_optimizers
 from pytorch_optimizer.pcgrad import PCGrad
 from pytorch_optimizer.radam import RAdam
+from pytorch_optimizer.ralamb import RaLamb
 from pytorch_optimizer.ranger import Ranger
 from pytorch_optimizer.ranger21 import Ranger21
 from pytorch_optimizer.sam import SAM

pytorch_optimizer/lamb.py (2 additions, 3 deletions)

@@ -35,7 +35,7 @@ def __init__(
         adamd_debias_term: bool = False,
         pre_norm: bool = False,
     ):
-        """
+        """Lamb
         :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
         :param lr: float. learning rate
         :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
@@ -121,13 +121,12 @@ def step(self, closure: CLOSURE = None) -> float:

                 step_size = group['lr']

-                weight_norm = p.data.pow(2).sum().sqrt().clamp(0, self.clamp)
-
                 adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
                 if group['weight_decay'] != 0:
                     adam_step.add_(p.data, alpha=group['weight_decay'])

                 adam_norm = adam_step.pow(2).sum().sqrt()
+                weight_norm = p.data.pow(2).sum().sqrt().clamp(0, self.clamp)
                 if weight_norm == 0 or adam_norm == 0:
                     trust_ratio = 1.0
                 else:
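
The reordered block above boils down to Lamb's layer-wise trust ratio: the clamped weight norm divided by the norm of the Adam step, falling back to 1.0 when either norm is zero. A minimal standalone sketch of that calculation follows; the tensors below are made-up stand-ins and are not part of this commit:

import torch

clamp = 10.0                                 # same clamp value Lamb applies to the weight norm
param = torch.randn(128, 64)                 # stand-in for p.data
adam_step = torch.randn(128, 64) * 1e-3      # stand-in for exp_avg / (exp_avg_sq.sqrt() + eps)

adam_norm = adam_step.pow(2).sum().sqrt()
weight_norm = param.pow(2).sum().sqrt().clamp(0, clamp)
trust_ratio = 1.0 if weight_norm == 0 or adam_norm == 0 else weight_norm / adam_norm
# outside this hunk, the parameter update scales adam_step by step_size * trust_ratio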

pytorch_optimizer/optimizers.py (3 additions, 0 deletions)

@@ -8,6 +8,7 @@
 from pytorch_optimizer.lamb import Lamb
 from pytorch_optimizer.madgrad import MADGRAD
 from pytorch_optimizer.radam import RAdam
+from pytorch_optimizer.ralamb import RaLamb
 from pytorch_optimizer.ranger import Ranger
 from pytorch_optimizer.ranger21 import Ranger21
 from pytorch_optimizer.sgdp import SGDP
@@ -42,6 +43,8 @@ def load_optimizers(optimizer: str, use_fp16: bool = False):
         opt = AdaHessian
     elif optimizer == 'lamb':
         opt = Lamb
+    elif optimizer == 'ralamb':
+        opt = RaLamb
     else:
         raise NotImplementedError(f'[-] not implemented optimizer : {optimizer}')
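
With the name registered, the string-based loader should also resolve 'ralamb'. A hedged sketch of that path; the linear model and learning rate are illustrative, and with use_fp16 left at its default the loader is expected to hand back the RaLamb class itself:

import torch
from pytorch_optimizer import load_optimizers

model = torch.nn.Linear(10, 2)                         # toy model, for illustration only
optimizer_class = load_optimizers('ralamb')            # resolves the string name to the RaLamb class
optimizer = optimizer_class(model.parameters(), lr=1e-3)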

pytorch_optimizer/ralamb.py (new file, 202 additions)

import math

import torch
from torch.optim import Optimizer

from pytorch_optimizer.types import BETAS, BUFFER, CLOSURE, DEFAULTS, PARAMETERS
from pytorch_optimizer.utils import is_valid_parameters


class RaLamb(Optimizer):
    """
    Reference : https://gist.github.com/redknightlois/c4023d393eb8f92bb44b2ab582d7ec20
    Example :
        from pytorch_optimizer import RaLamb
        ...
        model = YourModel()
        optimizer = RaLamb(model.parameters())
        ...
        for input, output in data:
            optimizer.zero_grad()
            loss = loss_function(output, model(input))
            loss.backward()
            optimizer.step()
    """

    clamp: float = 10.0

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0.0,
        adamd_debias_term: bool = False,
        pre_norm: bool = False,
        n_sma_threshold: int = 5,
        degenerated_to_sgd: bool = False,
    ):
        """RaLamb
        :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
        :param lr: float. learning rate
        :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
        :param eps: float. term added to the denominator to improve numerical stability
        :param weight_decay: float. weight decay (L2 penalty)
        :param adamd_debias_term: bool. Only correct the denominator to avoid inflating step sizes early in training
        :param pre_norm: bool. perform pre-normalization of all gradients
        :param n_sma_threshold: int. (recommended is 5)
        :param degenerated_to_sgd: float. degenerated to SGD
        """
        self.lr = lr
        self.betas = betas
        self.weight_decay = weight_decay
        self.eps = eps
        self.adamd_debias_term = adamd_debias_term
        self.pre_norm = pre_norm
        self.n_sma_threshold = n_sma_threshold
        self.degenerated_to_sgd = degenerated_to_sgd

        self.check_valid_parameters()

        buffer: BUFFER = [[None, None, None] for _ in range(10)]

        if is_valid_parameters(params):
            for param in params:
                if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
                    param['buffer'] = buffer

        defaults: DEFAULTS = dict(
            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, adamd_debias_term=adamd_debias_term, buffer=buffer
        )

        super().__init__(params, defaults)

    def check_valid_parameters(self):
        if self.lr < 0.0:
            raise ValueError(f'Invalid learning rate : {self.lr}')
        if not 0.0 <= self.betas[0] < 1.0:
            raise ValueError(f'Invalid beta_0 : {self.betas[0]}')
        if not 0.0 <= self.betas[1] < 1.0:
            raise ValueError(f'Invalid beta_1 : {self.betas[1]}')
        if self.weight_decay < 0.0:
            raise ValueError(f'Invalid weight_decay : {self.weight_decay}')
        if self.eps < 0.0:
            raise ValueError(f'Invalid eps : {self.eps}')

    def get_gradient_norm(self) -> float:
        norm_sq: float = 0.0
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                norm_sq += torch.linalg.norm(p.grad).item() ** 2

        norm = math.sqrt(norm_sq)

        return norm

    def step(self, closure: CLOSURE = None) -> float:
        loss = None
        if closure is not None:
            loss = closure()

        grad_norm: float = 1.0
        if self.pre_norm:
            grad_norm = self.get_gradient_norm()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                if self.pre_norm:
                    p.grad /= grad_norm

                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('[-] Lamb does not support sparse gradients, consider SparseAdam instead.')

                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                state['step'] += 1
                buffered = group['buffer'][int(state['step'] % 10)]

                bias_correction1 = 1 - beta1 ** state['step']

                if state['step'] == buffered[0]:
                    n_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    n_sma_max = 2 / (1 - beta2) - 1
                    n_sma = n_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = n_sma

                    # more conservative since it's an approximated value
                    if n_sma >= self.n_sma_threshold:
                        rt = math.sqrt(
                            (1 - beta2_t)
                            * (n_sma - 4)
                            / (n_sma_max - 4)
                            * (n_sma - 2)
                            / n_sma
                            * n_sma_max
                            / (n_sma_max - 2)
                        )

                        if group['adamd_debias_term']:
                            step_size = rt
                        else:
                            step_size = rt / bias_correction1
                    elif self.degenerated_to_sgd:
                        step_size = 1.0 / bias_correction1
                    else:
                        step_size = group['lr'] / bias_correction1

                    buffered[2] = step_size

                if group['weight_decay'] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])

                radam_step = p_data_fp32.clone()
                if n_sma >= self.n_sma_threshold:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    radam_step.addcdiv_(exp_avg, denom, value=-step_size)
                else:
                    radam_step.add_(exp_avg, alpha=-step_size)

                radam_step = radam_step.pow(2).sum().sqrt()
                weight_norm = p.data.pow(2).sum().sqrt().clamp(0, self.clamp)
                if weight_norm == 0 or radam_step == 0:
                    trust_ratio = 1.0
                else:
                    trust_ratio = weight_norm / radam_step

                state['weight_norm'] = weight_norm
                state['adam_norm'] = radam_step
                state['trust_ratio'] = trust_ratio

                if n_sma >= self.n_sma_threshold:
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * trust_ratio)
                else:
                    p_data_fp32.add_(exp_avg, alpha=-step_size * trust_ratio)

        return loss
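
For reference, the usage pattern from the class docstring, expanded into a self-contained sketch; the model, data, loss, and hyperparameters below are placeholders and not part of the commit:

import torch
from pytorch_optimizer import RaLamb

# toy model and data, for illustration only
model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 1))
optimizer = RaLamb(model.parameters(), lr=1e-3, weight_decay=1e-2)

for _ in range(10):
    x, y = torch.randn(4, 16), torch.randn(4, 1)
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()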

pytorch_optimizer/ranger21.py (4 additions, 0 deletions)

@@ -205,6 +205,10 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 if p.grad is None:
                     continue

+                grad = p.grad.data
+                if grad.is_sparse:
+                    raise RuntimeError('Ranger21 does not support sparse gradients')
+
                 param_size += p.numel()

                 # Apply Adaptive Gradient Clipping (AGC)

pytorch_optimizer/version.py (1 addition, 1 deletion)

@@ -1 +1 @@
-__VERSION__ = '0.3.2'
+__VERSION__ = '0.3.3'

setup.py (1 addition, 0 deletions)

@@ -68,6 +68,7 @@ def read_version() -> str:
         'pcgrad',
         'adamd',
         'lamb',
+        'ralamb',
     ]
 )

tests/test_load_optimizers.py (1 addition, 0 deletions)

@@ -17,6 +17,7 @@
     'diffgrad',
     'diffrgrad',
     'lamb',
+    'ralamb',
 ]

 INVALID_OPTIMIZER_NAMES: List[str] = [

tests/test_optimizer_parameters.py (2 additions, 0 deletions)

@@ -17,6 +17,7 @@
     'diffgrad',
     'diffrgrad',
     'lamb',
+    'ralamb',
 ]

 BETA_OPTIMIZER_NAMES: List[str] = [
@@ -30,6 +31,7 @@
     'radam',
     'ranger',
     'ranger21',
+    'ralamb',
 ]
