
Commit 8259768

Merge pull request #186 from kozistr/feature/padam-optimizer
[Feature] Implement PAdam optimizer
2 parents 7dcaabb + 143b722

File tree

8 files changed (+145, -5 lines)

README.rst

Lines changed: 3 additions & 1 deletion
@@ -16,7 +16,7 @@ pytorch-optimizer
 
 | **pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch.
 | I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas.
-| Currently, 55 optimizers, 6 lr schedulers are supported!
+| Currently, 56 optimizers, 6 lr schedulers are supported!
 |
 | Highly inspired by `pytorch-optimizer <https://github.com/jettify/pytorch-optimizer>`__.
 
@@ -216,6 +216,8 @@ You can check the supported optimizers with below code.
 +--------------+---------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+
 | Prodigy | *An Expeditiously Adaptive Parameter-Free Learner* | `github <https://github.com/konstmish/prodigy>`__ | `https://arxiv.org/abs/2306.06101 <https://arxiv.org/abs/2306.06101>`__ | `cite <https://github.com/konstmish/prodigy#how-to-cite>`__ |
 +--------------+---------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+
+| PAdam | *Closing the Generalization Gap of Adaptive Gradient Methods in Training Deep Neural Networks* | `github <https://github.com/uclaml/Padam>`__ | `https://arxiv.org/abs/1806.06763 <https://arxiv.org/abs/1806.06763>`__ | `cite <https://github.com/uclaml/Padam#citation>`__ |
++--------------+---------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+
 
 Supported LR Scheduler
 ----------------------
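The second hunk above sits under the README sentence "You can check the supported optimizers with below code."; with this release that check should report 56 optimizers and include PAdam. A small sketch, assuming get_supported_optimizers (also exercised by the updated test near the end of this diff) returns the supported optimizer classes:

    from pytorch_optimizer import get_supported_optimizers

    optimizers = get_supported_optimizers()
    print(len(optimizers))                                      # expected to be 56 as of v2.11.0
    print(any(opt.__name__ == 'PAdam' for opt in optimizers))   # expected to be True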

docs/changelogs/v2.11.0.md

Lines changed: 10 additions & 0 deletions
## Change Log

### Feature

* Implement PAdam optimizer (#186)
    * [Closing the Generalization Gap of Adaptive Gradient Methods in Training Deep Neural Networks](https://arxiv.org/abs/1806.06763)

### Diff

[2.10.1...2.11.0](https://github.com/kozistr/pytorch_optimizer/compare/v2.10.1...v2.11.0)

docs/optimizer_api.rst

Lines changed: 8 additions & 0 deletions
@@ -496,3 +496,11 @@ Prodigy
 
 .. autoclass:: pytorch_optimizer.Prodigy
     :members:
+
+.. _PAdam:
+
+PAdam
+-----
+
+.. autoclass:: pytorch_optimizer.PAdam
+    :members:

pyproject.toml

Lines changed: 3 additions & 3 deletions
@@ -13,9 +13,9 @@ keywords = [
     "pytorch", "deep-learning", "optimizer", "lr scheduler", "A2Grad", "ASGD", "AccSGD", "AdaBelief", "AdaBound",
     "AdaDelta", "AdaFactor", "AdaMax", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdaHessian", "Adai", "AdamP",
     "AdamS", "Adan", "AggMo", "AliG", "Amos", "Apollo", "AvaGrad", "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan",
-    "DAdaptSGD", "DiffGrad", "Fromage", "Gravity", "LARS", "Lamb", "Lion", "MADGRAD", "MSVAG", "Nero", "NovoGrad",
-    "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "SGDP", "SGDW", "SignSGD", "SM3",
-    "SopihaH", "SRMM", "SWATS", "ScalableShampoo", "Shampoo", "Yogi", "SAM", "GSAM", "PCGrad", "RotoGrad",
+    "DAdaptSGD", "DiffGrad", "Fromage", "Gravity", "GSAM", "LARS", "Lamb", "Lion", "MADGRAD", "MSVAG", "Nero",
+    "NovoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad",
+    "SAM", "SGDP", "SGDW", "SignSGD", "SM3", "SopihaH", "SRMM", "SWATS", "ScalableShampoo", "Shampoo", "Yogi",
 ]
 classifiers = [
     "License :: OSI Approved :: Apache Software License",

pytorch_optimizer/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -53,6 +53,7 @@
 from pytorch_optimizer.optimizer.msvag import MSVAG
 from pytorch_optimizer.optimizer.nero import Nero
 from pytorch_optimizer.optimizer.novograd import NovoGrad
+from pytorch_optimizer.optimizer.padam import PAdam
 from pytorch_optimizer.optimizer.pcgrad import PCGrad
 from pytorch_optimizer.optimizer.pid import PID
 from pytorch_optimizer.optimizer.pnm import PNM
@@ -154,6 +155,7 @@
     SophiaH,
     SignSGD,
     Prodigy,
+    PAdam,
 ]
 OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
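With the import and the OPTIMIZER_LIST entry above, PAdam is now exposed at the package root and is also reachable by name through the lowercased-class-name registry built on the last line. A minimal sketch of what that enables (the assert is illustrative, not part of the commit):

    from pytorch_optimizer import OPTIMIZERS, PAdam

    # the registry keys are lowercased class names, so the string 'padam' resolves to the new class
    assert OPTIMIZERS['padam'] is PAdam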

pytorch_optimizer/optimizer/padam.py

Lines changed: 115 additions & 0 deletions

import math

import torch
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS


class PAdam(Optimizer, BaseOptimizer):
    """Closing the Generalization Gap of Adaptive Gradient Methods in Training Deep Neural Networks.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of the gradient and its square.
    :param partial: float. partially adaptive parameter.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-1,
        betas: BETAS = (0.9, 0.999),
        partial: float = 0.25,
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        eps: float = 1e-8,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_range(partial, 'partial', 0.0, 1.0, range_type='(]')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'partial': partial,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'eps': eps,
        }
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'PAdam'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # the step counter lives on the parameter group
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            beta1, beta2 = group['betas']

            # Adam-style bias corrections for the first and second moments
            bias_correction1: float = 1.0 - beta1 ** group['step']
            bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step'])

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]
                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)

                self.apply_weight_decay(
                    p,
                    grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                # exponential moving averages of the gradient and its element-wise square
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                de_nom = exp_avg_sq.sqrt().add_(group['eps'])

                step_size: float = group['lr'] * bias_correction2_sq / bias_correction1

                # partially adaptive step: raising the denominator to 2 * partial means that
                # partial=0.5 recovers the Adam update, while smaller values damp the adaptivity
                p.addcdiv_(exp_avg, de_nom ** (group['partial'] * 2), value=-step_size)

        return loss
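Taken together, step() above implements the partially adaptive update of the paper. Restated in the notation of the code (this equation is an editorial summary, not text from the commit), with m_t = exp_avg, v_t = exp_avg_sq and p = partial:

    \theta_{t+1} = \theta_t - \mathrm{lr} \cdot \frac{\sqrt{1 - \beta_2^t}}{1 - \beta_1^t} \cdot \frac{m_t}{\left(\sqrt{v_t} + \epsilon\right)^{2p}}

With p = 0.5 the denominator reduces to Adam's \sqrt{v_t} + \epsilon; as p approaches 0 the adaptive scaling disappears and the update approaches SGD with momentum. The default p = 0.25 sits in between, which is the paper's recipe for closing the generalization gap.

A minimal usage sketch (the toy model and data are hypothetical and only for illustration; the PAdam import and its constructor arguments come from this commit):

    import torch
    import torch.nn.functional as F

    from pytorch_optimizer import PAdam

    # hypothetical toy regression model, just to exercise the optimizer
    model = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 1))
    optimizer = PAdam(model.parameters(), lr=1e-1, partial=0.25, weight_decay=1e-3)

    x, y = torch.randn(32, 8), torch.randn(32, 1)
    for _ in range(10):
        optimizer.zero_grad()
        loss = F.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()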

tests/constants.py

Lines changed: 3 additions & 0 deletions
@@ -47,6 +47,7 @@
     Lion,
     Nero,
     NovoGrad,
+    PAdam,
     Prodigy,
     QHAdam,
     RAdam,
@@ -109,6 +110,7 @@
     'adashift',
     'sophiah',
     'prodigy',
+    'padam',
 ]
 
 VALID_LR_SCHEDULER_NAMES: List[str] = [
@@ -385,6 +387,7 @@
     (Prodigy, {'lr': 5e1, 'beta3': 0.999, 'weight_decay': 1e-3}, 10),
     (Prodigy, {'lr': 1e1, 'beta3': 0.999, 'weight_decay': 1e-3, 'bias_correction': True}, 15),
     (Prodigy, {'lr': 1e0, 'beta3': 0.999, 'weight_decay': 1e-3, 'safeguard_warmup': True}, 15),
+    (PAdam, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
 ]
 ADANORM_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
     (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adanorm': True}, 10),

tests/test_load_optimizers.py

Lines changed: 1 addition & 1 deletion
@@ -16,4 +16,4 @@ def test_load_optimizers_invalid(invalid_optimizer_names):
 
 
 def test_get_supported_optimizers():
-    assert len(get_supported_optimizers()) == 55
+    assert len(get_supported_optimizers()) == 56
