
Commit 6a74fdc

add lion with cautious update, from Liang et al.
1 parent 6d093f1 commit 6a74fdc

4 files changed: 104 additions, 2 deletions


README.md

Lines changed: 9 additions & 0 deletions
````diff
@@ -109,3 +109,12 @@ opt = Lion(
     url = {https://fabian-sp.github.io/posts/2024/02/decoupling/}
 }
 ```
+
+```bibtex
+@inproceedings{Liang2024CautiousOI,
+    title = {Cautious Optimizers: Improving Training with One Line of Code},
+    author = {Kaizhao Liang and Lizhang Chen and Bo Liu and Qiang Liu},
+    year = {2024},
+    url = {https://api.semanticscholar.org/CorpusID:274234738}
+}
+```
````

lion_pytorch/cautious_lion.py

Lines changed: 92 additions & 0 deletions
New file:

```python
from __future__ import annotations
from typing import Tuple, Callable

import torch
from torch.optim.optimizer import Optimizer

# functions

def exists(val):
    return val is not None

# class

class Lion(Optimizer):
    def __init__(
        self,
        params,
        lr: float = 1e-4,
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0,
        cautious_factor: float = 0.,
        decoupled_weight_decay: bool = False,
    ):
        assert lr > 0.
        assert all([0. <= beta <= 1. for beta in betas])
        assert 0. <= cautious_factor <= 1.

        self._init_lr = lr
        self.decoupled_wd = decoupled_weight_decay

        defaults = dict(
            lr = lr,
            betas = betas,
            weight_decay = weight_decay,
            cautious_factor = cautious_factor
        )

        super().__init__(params, defaults)

    @torch.no_grad()
    def step(
        self,
        closure: Callable | None = None
    ):

        loss = None
        if exists(closure):
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in filter(lambda p: exists(p.grad), group['params']):

                grad, lr, wd, cautious_factor, beta1, beta2, state, decoupled_wd, init_lr = p.grad, group['lr'], group['weight_decay'], group['cautious_factor'], *group['betas'], self.state[p], self.decoupled_wd, self._init_lr

                # maybe decoupled weight decay

                if decoupled_wd:
                    wd /= init_lr

                # init state - exponential moving average of gradient values

                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)

                exp_avg = state['exp_avg']

                # stepweight decay

                p.data.mul_(1. - lr * wd)

                # weight update

                update = exp_avg.clone().mul_(beta1).add(grad, alpha = 1. - beta1).sign_()

                # maybe cautious update - algorithm 2 in https://arxiv.org/abs/2411.16085

                if cautious_factor < 1.:
                    align_mask = (update * grad) > 0
                    scale = torch.where(align_mask, torch.ones_like(grad), cautious_factor)
                    scale /= scale.mean().clamp(min = 1e-5)
                    update.mul_(scale)

                # update params

                p.add_(update, alpha = -lr)

                # decay the momentum running average coefficient

                exp_avg.mul_(beta2).add_(grad, alpha = 1. - beta2)

        return loss
```
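For orientation, a minimal usage sketch of the new optimizer (not part of the commit; the import path `lion_pytorch.cautious_lion` is read off this commit's file tree, and the toy model and batch are illustrative only). With `cautious_factor = 0.`, entries of the sign update that disagree in sign with the current gradient are zeroed and the surviving entries are rescaled by the inverse mean of the mask, following Algorithm 2 of the paper; `cautious_factor = 1.` recovers the plain Lion update.

```python
# Usage sketch only: the import path is taken from this commit's file tree,
# and the toy model / random batch are made up for illustration.
import torch
from lion_pytorch.cautious_lion import Lion

model = torch.nn.Linear(16, 1)

opt = Lion(
    model.parameters(),
    lr = 1e-4,
    weight_decay = 1e-2,
    cautious_factor = 0.   # 0. = fully cautious masking, 1. = plain Lion
)

loss = model(torch.randn(8, 16)).sum()
loss.backward()
opt.step()
opt.zero_grad()
```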

lion_pytorch/foreach.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -78,7 +78,8 @@ def step(
 
             # stepweight decay
 
-            torch._foreach_mul_(params, 1. - lr * wd)
+            if wd > 0.:
+                torch._foreach_mul_(params, 1. - lr * wd)
 
             # weight update
 
```
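The guard added here is a pure skip: when `wd == 0.` the multiplier `1. - lr * wd` is exactly `1.`, so the in-place foreach multiply would change nothing and the kernel launch can be avoided. A standalone sketch of the guarded call (the tensor list is hypothetical; only `torch._foreach_mul_` itself comes from the diff above):

```python
# Standalone sketch of the guarded stepweight decay; the parameter list below
# is made up for illustration, not taken from the repo.
import torch

params = [torch.randn(4, 4) for _ in range(3)]
lr, wd = 1e-4, 0.0

if wd > 0.:
    # multiplies every tensor in the list by (1 - lr * wd), in place
    torch._foreach_mul_(params, 1. - lr * wd)
```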

setup.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -3,7 +3,7 @@
 setup(
   name = 'lion-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.2.2',
+  version = '0.2.3',
   license='MIT',
   description = 'Lion Optimizer - Pytorch',
   author = 'Phil Wang',
```
