Commit af7aeef

committed
feature: Lion optimizer
1 parent a97f597 commit af7aeef

File tree

1 file changed: +100 -0 lines changed

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
import math
from typing import List

import torch
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS


class Lion(Optimizer, BaseOptimizer):
    r"""Symbolic Discovery of Optimization Algorithms.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for interpolating the update direction and for the running average of the gradient.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-4,
        betas: BETAS = (0.9, 0.99),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
    ):
        self.lr = lr
        self.betas = betas
        self.weight_decay = weight_decay
        self.weight_decouple = weight_decouple

        self.validate_parameters()

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
        }
        super().__init__(params, defaults)

    def validate_parameters(self):
        self.validate_learning_rate(self.lr)
        self.validate_betas(self.betas)
        self.validate_weight_decay(self.weight_decay)

    @property
    def __str__(self) -> str:
        return 'Lion'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]

                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group['betas']
            weight_decay = group['weight_decay']
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(self.__str__)

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)

                state['step'] += 1
                exp_avg = state['exp_avg']

                if weight_decay > 0.0:
                    if self.weight_decouple:
                        # decoupled (AdamW-style) weight decay applied directly to the weights
                        p.mul_(1.0 - group['lr'] * weight_decay)
                    else:
                        # classic L2 penalty folded into the gradient
                        grad.add_(p, alpha=weight_decay)

                # update direction: sign of the interpolation between the momentum and the current gradient
                update = exp_avg.clone().mul_(beta1).add_(grad, alpha=1.0 - beta1)
                # the momentum buffer itself is tracked separately with beta2
                exp_avg.mul_(beta2).add_(grad, alpha=1.0 - beta2)

                p.add_(update.sign_(), alpha=-group['lr'])

        return loss
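
For reference, a minimal usage sketch (not part of this commit). It assumes Lion is importable from the top-level pytorch_optimizer package, which this diff does not show, so adjust the import to the actual module path.

import torch
import torch.nn.functional as F

from pytorch_optimizer import Lion  # assumed export path; not shown in this diff

model = torch.nn.Linear(10, 1)
optimizer = Lion(model.parameters(), lr=1e-4, betas=(0.9, 0.99), weight_decay=1e-2)

x, y = torch.randn(32, 10), torch.randn(32, 1)
for _ in range(10):
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()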
