
Commit 3cd5158

Merge pull request #50 from kozistr/feature/lars-optimizer
[Feature] Implement LARS optimizer
2 parents 16aeb2c + 9fadf57 commit 3cd5158

9 files changed (+117, -1 lines changed)

pytorch_optimizer/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@
 from pytorch_optimizer.fp16 import DynamicLossScaler, SafeFP16Optimizer
 from pytorch_optimizer.gc import centralize_gradient
 from pytorch_optimizer.lamb import Lamb
+from pytorch_optimizer.lars import LARS
 from pytorch_optimizer.lookahead import Lookahead
 from pytorch_optimizer.madgrad import MADGRAD
 from pytorch_optimizer.optimizers import load_optimizers

pytorch_optimizer/lars.py

Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
import torch
from torch.optim import Optimizer

from pytorch_optimizer.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS


class LARS(Optimizer):
    """
    Reference : https://github.com/facebookresearch/mae/blob/main/util/lars.py
    Example :
        from pytorch_optimizer import LARS
        ...
        model = YourModel()
        optimizer = LARS(model.parameters())
        ...
        for input, output in data:
            optimizer.zero_grad()
            loss = loss_function(output, model(input))
            loss.backward()
            optimizer.step()
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        weight_decay: float = 0.0,
        momentum: float = 0.9,
        trust_coefficient: float = 0.001,
        eps: float = 1e-6,
    ):
        """LARS optimizer, no rate scaling or weight decay for parameters <= 1D
        :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
        :param lr: float. learning rate
        :param weight_decay: float. weight decay (L2 penalty)
        :param momentum: float. momentum
        :param trust_coefficient: float. trust_coefficient
        :param eps: float. epsilon
        """
        self.lr = lr
        self.weight_decay = weight_decay
        self.momentum = momentum
        self.trust_coefficient = trust_coefficient
        self.eps = eps

        self.check_valid_parameters()

        defaults: DEFAULTS = dict(
            lr=lr,
            weight_decay=weight_decay,
            momentum=momentum,
            trust_coefficient=trust_coefficient,
        )
        super().__init__(params, defaults)

    def check_valid_parameters(self):
        if self.lr < 0.0:
            raise ValueError(f'Invalid learning rate : {self.lr}')
        if self.weight_decay < 0.0:
            raise ValueError(f'Invalid weight_decay : {self.weight_decay}')
        if self.momentum < 0.0:
            raise ValueError(f'Invalid momentum : {self.momentum}')
        if self.trust_coefficient < 0.0:
            raise ValueError(f'Invalid trust_coefficient : {self.trust_coefficient}')
        if self.eps < 0.0:
            raise ValueError(f'Invalid eps : {self.eps}')

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            loss = closure()

        for g in self.param_groups:
            for p in g['params']:
                if p.grad is None:
                    continue

                if p.grad.data.is_sparse:
                    raise RuntimeError('LARS does not support sparse gradients')

                dp = p.grad

                if p.ndim > 1:  # if not normalization gamma/beta or bias
                    dp = dp.add(p, alpha=g['weight_decay'])
                    param_norm = torch.norm(p)
                    update_norm = torch.norm(dp)
                    one = torch.ones_like(param_norm)

                    q = torch.where(
                        param_norm > 0.0,
                        torch.where(update_norm > 0.0, (g['trust_coefficient'] * param_norm / update_norm), one),
                        one,
                    )
                    dp = dp.mul(q)

                param_state = self.state[p]
                if 'mu' not in param_state:
                    param_state['mu'] = torch.zeros_like(p)

                mu = param_state['mu']
                mu.mul_(g['momentum']).add_(dp)

                p.add_(mu, alpha=-g['lr'])

        return loss
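
Note: the core of the step above is the layer-wise trust ratio, applied only to tensors with more than one dimension. Below is a minimal standalone sketch of that scaling for a single toy weight tensor; the tensor shapes and hyperparameter values are illustrative and not part of this commit.

import torch

# Toy 2D weight and gradient; in the optimizer these come from model parameters.
p = torch.randn(4, 3)
grad = torch.randn(4, 3)

weight_decay = 1e-3
trust_coefficient = 0.001

# Weight decay is folded into the raw gradient, as in LARS.step above.
dp = grad.add(p, alpha=weight_decay)

param_norm = torch.norm(p)
update_norm = torch.norm(dp)
one = torch.ones_like(param_norm)

# Trust ratio: trust_coefficient * ||p|| / ||dp||, falling back to 1.0
# whenever either norm is zero.
q = torch.where(
    param_norm > 0.0,
    torch.where(update_norm > 0.0, trust_coefficient * param_norm / update_norm, one),
    one,
)
scaled_update = dp.mul(q)  # this is what feeds the momentum buffer 'mu'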

pytorch_optimizer/optimizers.py

Lines changed: 3 additions & 0 deletions
@@ -6,6 +6,7 @@
 from pytorch_optimizer.diffrgrad import DiffRGrad
 from pytorch_optimizer.fp16 import SafeFP16Optimizer
 from pytorch_optimizer.lamb import Lamb
+from pytorch_optimizer.lars import LARS
 from pytorch_optimizer.madgrad import MADGRAD
 from pytorch_optimizer.radam import RAdam
 from pytorch_optimizer.ralamb import RaLamb
@@ -45,6 +46,8 @@ def load_optimizers(optimizer: str, use_fp16: bool = False):
         opt = Lamb
     elif optimizer == 'ralamb':
         opt = RaLamb
+    elif optimizer == 'lars':
+        opt = LARS
     else:
         raise NotImplementedError(f'[-] not implemented optimizer : {optimizer}')
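
With this registration, LARS can also be constructed by its string name. A minimal sketch of that path, assuming load_optimizers returns the optimizer class when use_fp16 is left at False (as the branch above suggests); the Linear model and hyperparameter values are illustrative, not part of the commit.

import torch
from pytorch_optimizer import load_optimizers

model = torch.nn.Linear(10, 2)  # illustrative model

lars_cls = load_optimizers('lars')  # resolves to the LARS class registered above
optimizer = lars_cls(model.parameters(), lr=1e-1, weight_decay=1e-3)

# one standard training step
optimizer.zero_grad()
loss = model(torch.randn(4, 10)).sum()
loss.backward()
optimizer.step()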

pytorch_optimizer/version.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__VERSION__ = '0.3.6'
+__VERSION__ = '0.3.7'

setup.py

Lines changed: 1 addition & 0 deletions
@@ -69,6 +69,7 @@ def read_version() -> str:
         'adamd',
         'lamb',
         'ralamb',
+        'lars',
     ]
 )

tests/test_load_optimizers.py

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@
     'diffrgrad',
     'lamb',
     'ralamb',
+    'lars',
 ]

 INVALID_OPTIMIZER_NAMES: List[str] = [

tests/test_optimizer_parameters.py

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
     'diffrgrad',
     'lamb',
     'ralamb',
+    'lars',
 ]

 BETA_OPTIMIZER_NAMES: List[str] = [

tests/test_optimizers.py

Lines changed: 2 additions & 0 deletions
@@ -7,6 +7,7 @@
 from torch.nn import functional as F

 from pytorch_optimizer import (
+    LARS,
     MADGRAD,
     SAM,
     SGDP,
@@ -94,6 +95,7 @@ def build_lookahead(*parameters, **kwargs):
     (DiffRGrad, {'lr': 5e-1, 'weight_decay': 1e-3}, 200),
     (Lamb, {'lr': 1e-1, 'weight_decay': 1e-3}, 500),
     (Lamb, {'lr': 1e-1, 'weight_decay': 1e-3, 'pre_norm': True, 'eps': 1e-8}, 500),
+    (LARS, {'lr': 1e-1, 'weight_decay': 1e-3}, 500),
     (RaLamb, {'lr': 1e-1, 'weight_decay': 1e-3}, 200),
     (MADGRAD, {'lr': 1e-2, 'weight_decay': 1e-3}, 500),
     (RAdam, {'lr': 1e-1, 'weight_decay': 1e-3}, 200),

tests/test_sparse_gradient.py

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@
     'diffrgrad',
     'lamb',
     'ralamb',
+    'lars',
 ]
