Skip to content

Commit b316ef9

Browse files
authored
Merge pull request #251 from kozistr/feature/kate-optimizer
[Feature] Implement Kate optimizer
2 parents 908e82e + 862ec9d commit b316ef9

File tree

11 files changed

+292
-171
lines changed

11 files changed

+292
-171
lines changed

README.md

Lines changed: 75 additions & 74 deletions
Large diffs are not rendered by default.

docs/changelogs/v3.0.2.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
* Implement `WSD` LR Scheduler. (#247, #248)
66
* [Warmup-Stable-Decay LR Scheduler](https://arxiv.org/abs/2404.06395)
77
* Add more Pytorch built-in lr schedulers. (#248)
8+
* Implement `Kate` optimizer. (#249, #251)
9+
* [Remove that Square Root: A New Efficient Scale-Invariant Version of AdaGrad](https://arxiv.org/abs/2403.02648)
810

911
### Refactor
1012

docs/index.md

Lines changed: 75 additions & 74 deletions
Large diffs are not rendered by default.

docs/optimizer.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,10 @@
172172
:docstring:
173173
:members:
174174

175+
::: pytorch_optimizer.Kate
176+
:docstring:
177+
:members:
178+
175179
::: pytorch_optimizer.Lamb
176180
:docstring:
177181
:members:

poetry.lock

Lines changed: 19 additions & 19 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ keywords = [
1414
"AdaDelta", "AdaFactor", "AdaMax", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdaHessian", "Adai", "Adalite",
1515
"AdamP", "AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos", "Apollo", "AvaGrad", "bSAM", "CAME", "DAdaptAdaGrad",
1616
"DAdaptAdam", "DAdaptAdan", "DAdaptSGD", "DAdaptLion", "DiffGrad", "FAdam", "Fromage", "GaLore", "Gravity",
17-
"GrokFast", "GSAM", "LARS", "Lamb", "Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Nero", "NovoGrad", "PAdam",
18-
"PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM",
17+
"GrokFast", "GSAM", "Kate", "Lamb", "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Nero", "NovoGrad",
18+
"PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM",
1919
"ScheduleFreeSGD", "ScheduleFreeAdamW", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SopihaH",
2020
"SRMM", "SWATS", "Tiger", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM",
2121
"Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD",

pytorch_optimizer/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
from pytorch_optimizer.optimizer.gc import centralize_gradient
6767
from pytorch_optimizer.optimizer.gravity import Gravity
6868
from pytorch_optimizer.optimizer.grokfast import GrokFastAdamW, gradfilter_ema, gradfilter_ma
69+
from pytorch_optimizer.optimizer.kate import Kate
6970
from pytorch_optimizer.optimizer.lamb import Lamb
7071
from pytorch_optimizer.optimizer.lars import LARS
7172
from pytorch_optimizer.optimizer.lion import Lion
@@ -199,6 +200,7 @@
199200
ScheduleFreeAdamW,
200201
FAdam,
201202
GrokFastAdamW,
203+
Kate,
202204
]
203205
OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
204206

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
import torch
2+
from torch.optim.optimizer import Optimizer
3+
4+
from pytorch_optimizer.base.exception import NoSparseGradientError
5+
from pytorch_optimizer.base.optimizer import BaseOptimizer
6+
from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS
7+
8+
9+
class Kate(Optimizer, BaseOptimizer):
10+
r"""Remove that Square Root: A New Efficient Scale-Invariant Version of AdaGrad.
11+
12+
:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
13+
:param lr: float. learning rate.
14+
:param delta: float. delta. 0.0 or 1e-8.
15+
:param weight_decay: float. weight decay (L2 penalty).
16+
:param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
17+
:param fixed_decay: bool. fix weight decay.
18+
:param eps: float. epsilon value.
19+
"""
20+
21+
def __init__(
22+
self,
23+
params: PARAMETERS,
24+
lr: float = 1e-3,
25+
delta: float = 0.0,
26+
weight_decay: float = 0.0,
27+
weight_decouple: bool = True,
28+
fixed_decay: bool = False,
29+
eps: float = 1e-8,
30+
):
31+
self.validate_learning_rate(lr)
32+
self.validate_range(delta, 'delta', 0.0, 1.0, '[)')
33+
self.validate_non_negative(weight_decay, 'weight_decay')
34+
self.validate_non_negative(eps, 'eps')
35+
36+
defaults: DEFAULTS = {
37+
'lr': lr,
38+
'delta': delta,
39+
'weight_decay': weight_decay,
40+
'weight_decouple': weight_decouple,
41+
'fixed_decay': fixed_decay,
42+
'eps': eps,
43+
}
44+
45+
super().__init__(params, defaults)
46+
47+
def __str__(self) -> str:
48+
return 'Kate'
49+
50+
@torch.no_grad()
51+
def reset(self):
52+
for group in self.param_groups:
53+
group['step'] = 0
54+
for p in group['params']:
55+
state = self.state[p]
56+
57+
state['m'] = torch.zeros_like(p)
58+
state['b'] = torch.zeros_like(p)
59+
60+
@torch.no_grad()
61+
def step(self, closure: CLOSURE = None) -> LOSS:
62+
loss: LOSS = None
63+
if closure is not None:
64+
with torch.enable_grad():
65+
loss = closure()
66+
67+
for group in self.param_groups:
68+
if 'step' in group:
69+
group['step'] += 1
70+
else:
71+
group['step'] = 1
72+
73+
for p in group['params']:
74+
if p.grad is None:
75+
continue
76+
77+
grad = p.grad
78+
if grad.is_sparse:
79+
raise NoSparseGradientError(str(self))
80+
81+
state = self.state[p]
82+
83+
if len(state) == 0:
84+
state['m'] = torch.zeros_like(p)
85+
state['b'] = torch.zeros_like(p)
86+
87+
self.apply_weight_decay(
88+
p=p,
89+
grad=p.grad,
90+
lr=group['lr'],
91+
weight_decay=group['weight_decay'],
92+
weight_decouple=group['weight_decouple'],
93+
fixed_decay=group['fixed_decay'],
94+
)
95+
96+
grad_p2 = torch.mul(grad, grad)
97+
98+
m, b = state['m'], state['b']
99+
b.mul_(b).add_(grad_p2).add_(group['eps'])
100+
101+
m.mul_(m).add_(grad_p2, alpha=group['delta']).add_(grad_p2 / b).sqrt_()
102+
103+
update = m.mul(grad).div_(b)
104+
105+
p.add_(update, alpha=-group['lr'])
106+
107+
b.sqrt_()
108+
109+
return loss

requirements-dev.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ platformdirs==4.2.2 ; python_version >= "3.8" and python_full_version < "4.0.0"
2323
pluggy==1.5.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
2424
pytest-cov==5.0.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
2525
pytest==8.2.2 ; python_version >= "3.8" and python_full_version < "4.0.0"
26-
ruff==0.5.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
26+
ruff==0.5.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
2727
sympy==1.12.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
2828
tbb==2021.13.0 ; python_version >= "3.8" and python_full_version < "4.0.0" and platform_system == "Windows"
2929
tomli==2.0.1 ; python_version >= "3.8" and python_full_version <= "3.11.0a6"

tests/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
GaLore,
5151
Gravity,
5252
GrokFastAdamW,
53+
Kate,
5354
Lamb,
5455
Lion,
5556
Nero,
@@ -461,6 +462,7 @@
461462
(ScheduleFreeAdamW, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
462463
(FAdam, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
463464
(GrokFastAdamW, {'lr': 1e0, 'weight_decay': 1e-3}, 10),
465+
(Kate, {'lr': 5e-2}, 10),
464466
]
465467
ADANORM_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
466468
(AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adanorm': True}, 10),

0 commit comments

Comments
 (0)