
Commit 9e637b2

Merge pull request #40 from kozistr/feature/lamb-optimizer
[Feature] Implement LAMB optimizer
2 parents d39b528 + 42d4d2f commit 9e637b2

File tree

6 files changed (+156, -7 lines)


README.rst

Lines changed: 2 additions & 0 deletions

@@ -72,6 +72,8 @@ Supported Optimizers
 +--------------+----------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
 | Ranger21 | *a synergistic deep learning optimizer* | `github <https://github.com/lessw2020/Ranger21>`__ | `https://arxiv.org/abs/2106.13731 <https://arxiv.org/abs/2106.13731>`__ |
 +--------------+----------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
+| Lamb | *Large Batch Optimization for Deep Learning* | `github <https://github.com/cybertronai/pytorch-lamb>`__ | `https://arxiv.org/abs/1904.00962 <https://arxiv.org/abs/1904.00962>`__ |
++--------------+----------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+

 Useful Resources
 ----------------

pytorch_optimizer/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -9,6 +9,7 @@
 from pytorch_optimizer.diffrgrad import DiffRGrad
 from pytorch_optimizer.fp16 import DynamicLossScaler, SafeFP16Optimizer
 from pytorch_optimizer.gc import centralize_gradient
+from pytorch_optimizer.lamb import Lamb
 from pytorch_optimizer.lookahead import Lookahead
 from pytorch_optimizer.madgrad import MADGRAD
 from pytorch_optimizer.optimizers import load_optimizers

pytorch_optimizer/lamb.py

Lines changed: 145 additions & 0 deletions (new file)

@@ -0,0 +1,145 @@
import math

import torch
from torch.optim import Optimizer

from pytorch_optimizer.types import BETAS, CLOSURE, DEFAULTS, PARAMETERS


class Lamb(Optimizer):
    """
    Reference : https://github.com/cybertronai/pytorch-lamb/blob/master/pytorch_lamb/lamb.py
    Example :
        from pytorch_optimizer import Lamb
        ...
        model = YourModel()
        optimizer = Lamb(model.parameters())
        ...
        for input, output in data:
            optimizer.zero_grad()
            loss = loss_function(output, model(input))
            loss.backward()
            optimizer.step()
    """

    clamp: float = 10.0

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        eps: float = 1e-6,
        weight_decay: float = 0.0,
        adam: bool = False,
        adamd_debias_term: bool = False,
        pre_norm: bool = False,
    ):
        """
        :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
        :param lr: float. learning rate
        :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
        :param eps: float. term added to the denominator to improve numerical stability
        :param weight_decay: float. weight decay (L2 penalty)
        :param adamd_debias_term: bool. Only correct the denominator to avoid inflating step sizes early in training
        :param pre_norm: bool. perform pre-normalization of all gradients
        """
        self.lr = lr
        self.betas = betas
        self.weight_decay = weight_decay
        self.eps = eps
        self.adam = adam
        self.adamd_debias_term = adamd_debias_term
        self.pre_norm = pre_norm

        self.check_valid_parameters()

        defaults: DEFAULTS = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)

        super().__init__(params, defaults)

    def check_valid_parameters(self):
        if self.lr < 0.0:
            raise ValueError(f'Invalid learning rate : {self.lr}')
        if not 0.0 <= self.betas[0] < 1.0:
            raise ValueError(f'Invalid beta_0 : {self.betas[0]}')
        if not 0.0 <= self.betas[1] < 1.0:
            raise ValueError(f'Invalid beta_1 : {self.betas[1]}')
        if self.weight_decay < 0.0:
            raise ValueError(f'Invalid weight_decay : {self.weight_decay}')
        if self.eps < 0.0:
            raise ValueError(f'Invalid eps : {self.eps}')

    def get_gradient_norm(self) -> float:
        norm_sq: float = 0.0
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                norm_sq += torch.linalg.norm(p.grad).item() ** 2

        norm = math.sqrt(norm_sq)

        return norm

    def step(self, closure: CLOSURE = None) -> float:
        loss = None
        if closure is not None:
            loss = closure()

        grad_norm: float = 1.0
        if self.pre_norm:
            grad_norm = self.get_gradient_norm()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                if self.pre_norm:
                    p.grad /= grad_norm

                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('[-] Lamb does not support sparse gradients, consider SparseAdam instead.')

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p.data)
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                step_size = group['lr']

                weight_norm = p.data.pow(2).sum().sqrt().clamp(0, self.clamp)

                adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
                if group['weight_decay'] != 0:
                    adam_step.add_(p.data, alpha=group['weight_decay'])

                adam_norm = adam_step.pow(2).sum().sqrt()
                if weight_norm == 0 or adam_norm == 0:
                    trust_ratio = 1.0
                else:
                    trust_ratio = weight_norm / adam_norm

                state['weight_norm'] = weight_norm
                state['adam_norm'] = adam_norm
                state['trust_ratio'] = trust_ratio

                if self.adam:
                    trust_ratio = 1.0

                p.data.add_(adam_step, alpha=-step_size * trust_ratio)

        return loss
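
For reference, a minimal self-contained sketch of how the newly exported Lamb class can be used end to end. The linear model, synthetic data, and hyperparameter values below are illustrative assumptions, not part of this commit; only the constructor arguments shown come from the signature above.

import torch
from torch import nn

from pytorch_optimizer import Lamb

# toy regression setup, purely for illustration
model = nn.Linear(10, 1)
optimizer = Lamb(model.parameters(), lr=1e-3, weight_decay=1e-2, pre_norm=False)

x = torch.randn(32, 10)
y = torch.randn(32, 1)

for _ in range(100):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()

Setting adam=True forces the layer-wise trust ratio to 1.0, which reduces the update to the plain Adam-style step computed above it.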

pytorch_optimizer/utils.py

Lines changed: 6 additions & 6 deletions

@@ -18,9 +18,9 @@ def has_overflow(grad_norm: torch.Tensor) -> bool:

 def normalize_gradient(x: torch.Tensor, use_channels: bool = False, epsilon: float = 1e-8) -> torch.Tensor:
     """normalize gradient with stddev
-    :param x: torch.Tensor. gradient.
-    :param use_channels: bool. channel-wise normalization.
-    :param epsilon: float. eps.
+    :param x: torch.Tensor. gradient
+    :param use_channels: bool. channel-wise normalization
+    :param epsilon: float. eps
     :return: torch.Tensor. normalized gradient.
     """
     size: int = x.dim()

@@ -36,12 +36,12 @@ def normalize_gradient(x: torch.Tensor, use_channels: bool = False, epsilon: flo
 def clip_grad_norm(parameters: PARAMETERS, max_norm: float = 0, sync: bool = False) -> torch.Tensor:
     """Clips grad norms.
     During combination with FSDP, will also ensure that grad norms are aggregated
-    across all workers, since each worker only stores their shard of the gradients.
+    across all workers, since each worker only stores their shard of the gradients
     :param parameters: Parameters whose gradients we wish to clip
     :param max_norm: Maximum norm we wish the gradients to have. If non-positive, then
-        we will not perform clipping.
+        we will not perform clipping
     :param sync: Boolean indicating whether we should aggregate across the distributed
-        group. Used only in combination with FSDP.
+        group. Used only in combination with FSDP
     :returns: The gradient norm across all parameters, before clipping.
     """
     if isinstance(parameters, torch.Tensor):
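
As context for the clip_grad_norm docstring touched above, a hedged sketch of how the helper might be called between backward() and the optimizer step. The toy model, dummy loss, and max_norm value are illustrative assumptions; per the docstring, a non-positive max_norm skips clipping and sync=True only matters when gradients are sharded with FSDP.

import torch
from torch import nn

from pytorch_optimizer import Lamb
from pytorch_optimizer.utils import clip_grad_norm

model = nn.Linear(10, 1)                        # illustrative toy model
optimizer = Lamb(model.parameters(), lr=1e-3)

loss = model(torch.randn(8, 10)).pow(2).mean()  # dummy loss, for illustration only
loss.backward()

# clip gradients to a maximum L2 norm of 1.0; the return value is the norm before clipping
grad_norm = clip_grad_norm(model.parameters(), max_norm=1.0)
optimizer.step()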

pytorch_optimizer/version.py

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-__VERSION__ = '0.2.2'
+__VERSION__ = '0.3.0'

setup.py

Lines changed: 1 addition & 0 deletions

@@ -67,6 +67,7 @@ def read_version() -> str:
         'diffrgrad',
         'pcgrad',
         'adamd',
+        'lamb',
     ]
 )
