Commit baa65c4

Merge pull request #183 from kozistr/feature/prodigy-optimizer
[Feature] Implement Prodigy optimizer
2 parents 944a353 + 9d6f7e3 commit baa65c4

16 files changed (+367 / -142 lines)

README.rst

Lines changed: 3 additions & 1 deletion

@@ -16,7 +16,7 @@ pytorch-optimizer

 | **pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch.
 | I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas.
-| Currently, 54 optimizers, 6 lr schedulers are supported!
+| Currently, 55 optimizers, 6 lr schedulers are supported!
 |
 | Highly inspired by `pytorch-optimizer <https://github.com/jettify/pytorch-optimizer>`__.

@@ -216,6 +216,8 @@ You can check the supported optimizers with below code.
 +--------------+---------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+
 | Sophia | *A Scalable Stochastic Second-order Optimizer for Language Model Pre-training* | `github <https://github.com/Liuhong99/Sophia>`__ | `https://arxiv.org/abs/2305.14342 <https://arxiv.org/abs/2305.14342>`__ | `cite <https://github.com/Liuhong99/Sophia>`__ |
 +--------------+---------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+
+| Prodigy | *An Expeditiously Adaptive Parameter-Free Learner* | `github <https://github.com/konstmish/prodigy>`__ | `https://arxiv.org/abs/2306.06101 <https://arxiv.org/abs/2306.06101>`__ | `cite <https://github.com/konstmish/prodigy#how-to-cite>`__ |
++--------------+---------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+

 Supported LR Scheduler
 ----------------------

docs/changelogs/v2.10.1.md

Lines changed: 5 additions & 0 deletions

@@ -1,5 +1,10 @@
 ## Change Log

+### Feature
+
+* Implement Prodigy optimizer (#183)
+    * [An Expeditiously Adaptive Parameter-Free Learner](https://arxiv.org/abs/2306.06101)
+
 ### Fix

 * `perturb` isn't multiplied by `-step_size` in SWATS optimizer. (#179)

docs/optimizer_api.rst

Lines changed: 8 additions & 0 deletions

@@ -488,3 +488,11 @@ SophiaH

 .. autoclass:: pytorch_optimizer.SophiaH
     :members:
+
+.. _Prodigy:
+
+Prodigy
+-------
+
+.. autoclass:: pytorch_optimizer.Prodigy
+    :members:

pyproject.toml

Lines changed: 9 additions & 2 deletions

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "pytorch_optimizer"
-version = "2.10.0"
+version = "2.10.1"
 description = "optimizer & lr scheduler collections in PyTorch"
 license = "Apache-2.0"
 authors = ["kozistr <[email protected]>"]
@@ -9,7 +9,14 @@ readme = "README.rst"
 homepage = "https://github.com/kozistr/pytorch_optimizer"
 repository = "https://github.com/kozistr/pytorch_optimizer"
 documentation = "https://pytorch-optimizers.readthedocs.io/en/latest"
-keywords = ["pytorch", "deep-learning", "optimizer", "lr scheduler", "A2Grad", "ASGD", "AccSGD", "AdaBelief", "AdaBound", "AdaDelta", "AdaFactor", "AdaMax", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdaHessian", "Adai", "AdamP", "AdamS", "Adan", "AggMo", "AliG", "Amos", "Apollo", "AvaGrad", "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan", "DAdaptSGD", "DiffGrad", "Fromage", "Gravity", "LARS", "Lamb", "Lion", "MADGRAD", "MSVAG", "Nero", "NovoGrad", "PID", "PNM", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "SGDP", "SGDW", "SignSGD", "SM3", "SopihaH", "SRMM", "SWATS", "ScalableShampoo", "Shampoo", "Yogi", "SAM", "GSAM", "PCGrad", "RotoGrad"]
+keywords = [
+    "pytorch", "deep-learning", "optimizer", "lr scheduler", "A2Grad", "ASGD", "AccSGD", "AdaBelief", "AdaBound",
+    "AdaDelta", "AdaFactor", "AdaMax", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdaHessian", "Adai", "AdamP",
+    "AdamS", "Adan", "AggMo", "AliG", "Amos", "Apollo", "AvaGrad", "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan",
+    "DAdaptSGD", "DiffGrad", "Fromage", "Gravity", "LARS", "Lamb", "Lion", "MADGRAD", "MSVAG", "Nero", "NovoGrad",
+    "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "SGDP", "SGDW", "SignSGD", "SM3",
+    "SopihaH", "SRMM", "SWATS", "ScalableShampoo", "Shampoo", "Yogi", "SAM", "GSAM", "PCGrad", "RotoGrad",
+]
 classifiers = [
     "License :: OSI Approved :: Apache Software License",
     "Development Status :: 5 - Production/Stable",

pytorch_optimizer/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -56,6 +56,7 @@
 from pytorch_optimizer.optimizer.pcgrad import PCGrad
 from pytorch_optimizer.optimizer.pid import PID
 from pytorch_optimizer.optimizer.pnm import PNM
+from pytorch_optimizer.optimizer.prodigy import Prodigy
 from pytorch_optimizer.optimizer.qhadam import QHAdam
 from pytorch_optimizer.optimizer.qhm import QHM
 from pytorch_optimizer.optimizer.radam import RAdam
@@ -152,6 +153,7 @@
     AdaHessian,
     SophiaH,
     SignSGD,
+    Prodigy,
 ]
 OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
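Because the registry maps each optimizer class to its lowercased class name (as the dict comprehension above shows), the new optimizer should be reachable under the key 'prodigy'. A minimal sketch, relying only on the names visible in this diff; the toy linear model is illustrative:

    import torch

    from pytorch_optimizer import OPTIMIZERS, Prodigy

    # The registry key is the lowercased class name, per the comprehension in __init__.py.
    optimizer_cls = OPTIMIZERS['prodigy']
    assert optimizer_cls is Prodigy

    model = torch.nn.Linear(8, 2)
    optimizer = optimizer_cls(model.parameters())  # defaults: lr=1.0, betas=(0.9, 0.999)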

pytorch_optimizer/base/optimizer.py

Lines changed: 3 additions & 2 deletions

@@ -232,7 +232,7 @@ def validate_non_negative(x: float, name: str):

     @staticmethod
     def validate_positive(x: Union[float, int], name: str):
-        if x < 1:
+        if x <= 0:
             raise ValueError(f'[-] {name} must be positive')

     @staticmethod
@@ -265,7 +265,8 @@ def validate_betas(self, betas: BETAS):
         if len(betas) < 3:
             return

-        self.validate_range(betas[2], 'beta3', 0.0, 1.0, range_type='[]')
+        if betas[2] is not None:
+            self.validate_range(betas[2], 'beta3', 0.0, 1.0, range_type='[]')

     def validate_nus(self, nus: Union[float, Tuple[float, float]]):
         if isinstance(nus, float):
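The old guard rejected every value below 1, so a fractional but positive argument (e.g. 0.5) would have failed the positivity check; the new guard rejects only zero and negatives. A minimal sketch of the behavioral difference, written as standalone functions rather than the BaseOptimizer static methods:

    def validate_positive_old(x, name: str):
        if x < 1:  # old check: 0.5 would be rejected as "not positive"
            raise ValueError(f'[-] {name} must be positive')

    def validate_positive_new(x, name: str):
        if x <= 0:  # new check: only zero and negative values are rejected
            raise ValueError(f'[-] {name} must be positive')

    validate_positive_new(0.5, 'alpha')  # ok
    validate_positive_new(6, 'k')        # ok
    # validate_positive_old(0.5, 'alpha')  # would raise ValueError

The second hunk lets `validate_betas` accept a `None` third beta, which the new Prodigy constructor relies on when it calls `self.validate_betas((*betas, beta3))` with `beta3=None`.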

pytorch_optimizer/optimizer/adasmooth.py

Lines changed: 1 addition & 1 deletion

@@ -92,7 +92,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:

                 self.apply_weight_decay(
                     p=p,
-                    grad=p.grad,
+                    grad=grad,
                     lr=group['lr'],
                     weight_decay=group['weight_decay'],
                     weight_decouple=group['weight_decouple'],

pytorch_optimizer/optimizer/lars.py

Lines changed: 1 addition & 0 deletions

@@ -31,6 +31,7 @@ def __init__(
         self.validate_learning_rate(lr)
         self.validate_non_negative(weight_decay, 'weight_decay')
         self.validate_range(momentum, 'momentum', 0.0, 1.0)
+        self.validate_range(dampening, 'dampening', 0.0, 1.0)
         self.validate_non_negative(trust_coefficient, 'trust_coefficient')

         defaults: DEFAULTS = {

pytorch_optimizer/optimizer/prodigy.py

Lines changed: 186 additions & 0 deletions
@@ -0,0 +1,186 @@
import math
from typing import Optional

import torch
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS


class Prodigy(Optimizer, BaseOptimizer):
    r"""An Expeditiously Adaptive Parameter-Free Learner.

    Leave LR set to 1 unless you encounter instability.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. betas.
    :param beta3: float. coefficients for computing the Prodigy step-size using running averages. If set to None,
        uses the value of square root of beta2.
    :param d0: float. initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
    :param d_coef: float. Coefficient in the expression for the estimate of d.
    :param growth_rate: float. prevent the D estimate from growing faster than this multiplicative rate.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. use AdamW style weight decay.
    :param fixed_decay: bool. fix weight decay.
    :param bias_correction: bool. turn on Adam's bias correction.
    :param safeguard_warmup: bool. remove lr from the denominator of D estimate to avoid issues during warm-up stage.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1.0,
        betas: BETAS = (0.9, 0.999),
        beta3: Optional[float] = None,
        d0: float = 1e-6,
        d_coef: float = 1.0,
        growth_rate: float = float('inf'),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        bias_correction: bool = False,
        safeguard_warmup: bool = False,
        eps: float = 1e-8,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas((*betas, beta3))
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'beta3': beta3,
            'd': d0,
            'd0': d0,
            'd_max': d0,
            'd_coef': d_coef,
            'growth_rate': growth_rate,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'bias_correction': bias_correction,
            'safeguard_warmup': safeguard_warmup,
            'step': 1,
            'eps': eps,
        }
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Prodigy'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 1
            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]

                state['s'] = torch.zeros_like(p)
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        group = self.param_groups[0]
        device = group['params'][0].device

        d_de_nom = torch.tensor([0.0], device=device)

        beta1, beta2 = group['betas']
        beta3 = group['beta3'] if group['beta3'] is not None else math.sqrt(beta2)

        bias_correction1: float = 1.0 - beta1 ** group['step']
        bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step'])
        bias_correction: float = (bias_correction1 / bias_correction2_sq) if group['bias_correction'] else 1.0

        d, d0 = group['d'], group['d0']
        d_lr: float = d * group['lr'] / bias_correction

        if 'd_numerator' not in group:
            group['d_numerator'] = torch.tensor([0.0], device=device)

        d_numerator = group['d_numerator']
        d_numerator.mul_(beta3)

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]
                if len(state) == 0:
                    state['s'] = torch.zeros_like(p)
                    state['p0'] = p.clone()
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)

                p0, exp_avg, exp_avg_sq = state['p0'], state['exp_avg'], state['exp_avg_sq']

                d_numerator.add_(torch.dot(grad.flatten(), (p0 - p).flatten()), alpha=(d / d0) * d_lr)

                exp_avg.mul_(beta1).add_(grad, alpha=d * (1.0 - beta1))
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=d * d * (1.0 - beta2))

                s = state['s']
                s.mul_(beta3).add_(grad, alpha=(d / d0) * (d if group['safeguard_warmup'] else d_lr))

                d_de_nom.add_(s.abs().sum())

        if d_de_nom == 0:
            return loss

        d_hat = (group['d_coef'] * d_numerator / d_de_nom).item()
        if d == group['d0']:
            d = max(d, d_hat)

        d_max = max(group['d_max'], d_hat)
        d = min(d_max, d * group['growth_rate'])

        for group in self.param_groups:
            group['step'] += 1

            group['d_numerator'] = d_numerator
            group['d_de_nom'] = d_de_nom
            group['d'] = d
            group['d_max'] = d_max
            group['d_hat'] = d_hat

            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                de_nom = exp_avg_sq.sqrt().add_(d * group['eps'])

                self.apply_weight_decay(
                    p,
                    p.grad,
                    lr=d_lr,
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                p.addcdiv_(exp_avg, de_nom, value=-d_lr)

        return loss
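Per the docstring above ("Leave LR set to 1 unless you encounter instability"), a typical call keeps lr at its default and lets the internal d estimate adapt the effective step size. A minimal usage sketch; the toy model, data, and hyper-parameter choices are illustrative, not part of the PR:

    import torch

    from pytorch_optimizer import Prodigy

    model = torch.nn.Linear(10, 1)
    criterion = torch.nn.MSELoss()

    # lr stays at 1.0; Prodigy scales updates by its d estimate, which starts from d0=1e-6.
    optimizer = Prodigy(model.parameters(), lr=1.0, weight_decay=1e-2, safeguard_warmup=True)

    x, y = torch.randn(32, 10), torch.randn(32, 1)
    for _ in range(100):
        optimizer.zero_grad()
        criterion(model(x), y).backward()
        optimizer.step()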

pytorch_optimizer/optimizer/ranger.py

Lines changed: 4 additions & 3 deletions

@@ -30,12 +30,11 @@ def __init__(
         self,
         params: PARAMETERS,
         lr: float = 1e-3,
+        betas: BETAS = (0.95, 0.999),
         alpha: float = 0.5,
         k: int = 6,
         n_sma_threshold: int = 5,
         degenerated_to_sgd: bool = False,
-        betas: BETAS = (0.95, 0.999),
-        eps: float = 1e-5,
         weight_decay: float = 0.0,
         weight_decouple: bool = True,
         fixed_decay: bool = False,
@@ -44,11 +43,13 @@ def __init__(
         r: float = 0.95,
         adanorm: bool = False,
         adam_debias: bool = False,
+        eps: float = 1e-5,
     ):
         self.validate_learning_rate(lr)
         self.validate_betas(betas)
-        self.validate_non_negative(weight_decay, 'weight_decay')
+        self.validate_range(alpha, 'alpha', 0.0, 1.0, range_type='[]')
         self.validate_positive(k, 'k')
+        self.validate_non_negative(weight_decay, 'weight_decay')
         self.validate_non_negative(eps, 'eps')

         self.n_sma_threshold = n_sma_threshold
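Since `betas` and `eps` changed position in the Ranger signature, callers that passed them positionally could now bind them to different parameters; keyword arguments are unaffected by the reorder. A short sketch (the model is illustrative):

    import torch

    from pytorch_optimizer import Ranger

    model = torch.nn.Linear(10, 1)

    # Keyword arguments keep their meaning across the signature reorder.
    optimizer = Ranger(model.parameters(), lr=1e-3, betas=(0.95, 0.999), eps=1e-5, alpha=0.5, k=6)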
