
Commit 25618d7

Merge pull request #203 from kozistr/feature/dadaptlion-optimizer
[Feature] Implement DAdaptLion optimizer
2 parents 0a23375 + a20c12c commit 25618d7

File tree

13 files changed (+200, -27 lines)


README.rst

Lines changed: 9 additions & 9 deletions
@@ -16,7 +16,7 @@ pytorch-optimizer
 | **pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch.
 | I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas.
-| Currently, 59 optimizers, 10 lr schedulers, and 13 loss functions are supported!
+| Currently, **60 optimizers**, **10 lr schedulers**, and **13 loss functions** are supported!
 |
 | Highly inspired by `pytorch-optimizer <https://github.com/jettify/pytorch-optimizer>`__.

@@ -31,20 +31,20 @@ So, please double-check the license before using it at your work.
 Installation
 ~~~~~~~~~~~~

-::
+.. code-block:: bash

     $ pip3 install -U pytorch-optimizer

 If there's a version issue when installing the package, try with `--no-deps` option.

-::
+.. code-block:: bash

     $ pip3 install -U --no-deps pytorch-optimizer

 Simple Usage
 ~~~~~~~~~~~~

-::
+.. code-block:: python

     from pytorch_optimizer import AdamP

@@ -61,7 +61,7 @@ Simple Usage
 Also, you can load the optimizer via `torch.hub`

-::
+.. code-block:: python

     import torch

@@ -71,7 +71,7 @@ Also, you can load the optimizer via `torch.hub`
 If you want to build the optimizer with parameters & configs, there's `create_optimizer()` API.

-::
+.. code-block:: python

     from pytorch_optimizer import create_optimizer

@@ -89,7 +89,7 @@ Supported Optimizers
 You can check the supported optimizers with below code.

-::
+.. code-block:: python

     from pytorch_optimizer import get_supported_optimizers

@@ -230,7 +230,7 @@ Supported LR Scheduler
 You can check the supported learning rate schedulers with below code.

-::
+.. code-block:: python

     from pytorch_optimizer import get_supported_lr_schedulers

@@ -249,7 +249,7 @@ Supported Loss Function
 You can check the supported loss functions with below code.

-::
+.. code-block:: python

     from pytorch_optimizer import get_supported_loss_functions
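
The snippets touched by this directive change are the README's standard usage examples. As a point of reference, a minimal sketch of the simple-usage pattern they document (the toy model is illustrative; AdamP is one of the re-exported optimizers):

    import torch
    from torch import nn

    from pytorch_optimizer import AdamP

    model = nn.Linear(10, 2)  # any nn.Module works here
    optimizer = AdamP(model.parameters(), lr=1e-3, weight_decay=1e-2)

    loss = model(torch.randn(4, 10)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()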

docs/changelogs/v2.11.2.md

Lines changed: 15 additions & 0 deletions
@@ -1,9 +1,24 @@
 ## Change Log

+### Feature
+
+* Implement DAdaptLion optimizer (#203)
+    * [Lion with D-Adaptation](https://github.com/facebookresearch/dadaptation/blob/main/dadaptation/dadapt_lion.py)
+
 ### Fix

 * Fix Lookahead optimizer (#200, #201, #202)
     * When using PyTorch Lightning which expects your optimiser to be a subclass of `Optimizer`.
+* Fix default `rectify` to `False` in `AdaBelief` optimizer (#203)
+
+### Test
+
+* Add `DynamicLossScaler` test case
+
+### Docs
+
+* Highlight the code blocks
+* Fix pepy badges

 ### Contributions
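
For users, the headline entry is that `DAdaptLion` becomes importable from the package root. A minimal sketch of picking it up (defaults taken from the new class added in `dadapt.py` below; the model is illustrative):

    from torch import nn

    from pytorch_optimizer import DAdaptLion

    model = nn.Linear(10, 2)

    # lr is normally left at 1.0; D-Adaptation grows the step size from d0
    optimizer = DAdaptLion(model.parameters(), lr=1.0, d0=1e-6, weight_decay=1e-3)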

docs/optimizer_api.rst

Lines changed: 8 additions & 0 deletions
@@ -225,6 +225,14 @@ DAdaptAdan
 .. autoclass:: pytorch_optimizer.DAdaptAdan
     :members:

+.. _DAdaptLion:
+
+DAdaptLion
+----------
+
+.. autoclass:: pytorch_optimizer.DAdaptLion
+    :members:
+
 .. _AdamS:

 AdamS

pyproject.toml

Lines changed: 5 additions & 6 deletions
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "pytorch_optimizer"
-version = "2.11.1"
+version = "2.11.2"
 description = "optimizer & lr scheduler & objective function collections in PyTorch"
 license = "Apache-2.0"
 authors = ["kozistr <[email protected]>"]

@@ -13,10 +13,10 @@ keywords = [
     "pytorch", "deep-learning", "optimizer", "lr scheduler", "A2Grad", "ASGD", "AccSGD", "AdaBelief", "AdaBound",
     "AdaDelta", "AdaFactor", "AdaMax", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdaHessian", "Adai", "AdamP",
     "AdamS", "Adan", "AggMo", "AliG", "Amos", "Apollo", "AvaGrad", "CAME", "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan",
-    "DAdaptSGD", "DiffGrad", "Fromage", "Gravity", "GSAM", "LARS", "Lamb", "Lion", "LOMO", "Lookahead", "MADGRAD",
-    "MSVAG", "Nero", "NovoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger",
-    "Ranger21", "RotoGrad", "SAM", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SopihaH", "SRMM",
-    "SWATS", "Tiger", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard",
+    "DAdaptSGD", "DAdaptLion", "DiffGrad", "Fromage", "Gravity", "GSAM", "LARS", "Lamb", "Lion", "LOMO", "Lookahead",
+    "MADGRAD", "MSVAG", "Nero", "NovoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam",
+    "Ranger", "Ranger21", "RotoGrad", "SAM", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SopihaH",
+    "SRMM", "SWATS", "Tiger", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard",
     "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge",
 ]
 classifiers = [

@@ -122,7 +122,6 @@ testpaths = "tests"
 [tool.coverage.run]
 omit = [
     "./pytorch_optimizer/optimizer/gsam.py",
-    "./pytorch_optimizer/optimizer/fp16.py",
     "./pytorch_optimizer/optimizer/rotograd.py",
 ]

pytorch_optimizer/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -48,7 +48,7 @@
 from pytorch_optimizer.optimizer.apollo import Apollo
 from pytorch_optimizer.optimizer.avagrad import AvaGrad
 from pytorch_optimizer.optimizer.came import CAME
-from pytorch_optimizer.optimizer.dadapt import DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptSGD
+from pytorch_optimizer.optimizer.dadapt import DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptLion, DAdaptSGD
 from pytorch_optimizer.optimizer.diffgrad import DiffGrad
 from pytorch_optimizer.optimizer.fp16 import DynamicLossScaler, SafeFP16Optimizer
 from pytorch_optimizer.optimizer.fromage import Fromage

@@ -171,6 +171,7 @@
     LOMO,
     Tiger,
     CAME,
+    DAdaptLion,
 ]
 OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
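
Because the class is appended to `OPTIMIZER_LIST`, the `OPTIMIZERS` mapping built from it also registers the new entry under its lowercased name. A quick sanity-check sketch (it assumes `OPTIMIZERS` is imported directly from the package root, where it is defined):

    from pytorch_optimizer import OPTIMIZERS, DAdaptLion

    # keys come from str(optimizer.__name__).lower(), per the comprehension above
    assert OPTIMIZERS['dadaptlion'] is DAdaptLion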

pytorch_optimizer/optimizer/adabelief.py

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ def __init__(
         weight_decay: float = 0.0,
         weight_decouple: bool = True,
         fixed_decay: bool = False,
-        rectify: bool = True,
+        rectify: bool = False,
         n_sma_threshold: int = 5,
         degenerated_to_sgd: bool = True,
         ams_bound: bool = False,
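
Since the default flips from `True` to `False`, code that relied on the rectified (RAdam-style) update now has to opt in explicitly. A hedged sketch of keeping the previous behaviour (the model and learning rate are illustrative):

    from torch import nn

    from pytorch_optimizer import AdaBelief

    model = nn.Linear(10, 2)

    # rectify defaulted to True before v2.11.2; pass it explicitly to keep that behaviour
    optimizer = AdaBelief(model.parameters(), lr=1e-3, rectify=True)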

pytorch_optimizer/optimizer/dadapt.py

Lines changed: 129 additions & 0 deletions
@@ -699,3 +699,132 @@ def step(self, closure: CLOSURE = None) -> LOSS:
             group['k'] += 1

         return loss
+
+
+class DAdaptLion(Optimizer, BaseOptimizer):
+    r"""Lion with D-Adaptation. Leave LR set to 1 unless you encounter instability. This implementation is based on V3.
+
+    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
+    :param lr: float. learning rate.
+    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
+    :param d0: float. initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
+    :param weight_decay: float. weight decay (L2 penalty).
+    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
+    :param fixed_decay: bool. fix weight decay.
+    """
+
+    def __init__(
+        self,
+        params: PARAMETERS,
+        lr: float = 1.0,
+        betas: BETAS = (0.9, 0.999),
+        d0: float = 1e-6,
+        weight_decay: float = 0.0,
+        weight_decouple: bool = False,
+        fixed_decay: bool = False,
+    ):
+        self.validate_learning_rate(lr)
+        self.validate_betas(betas)
+        self.validate_non_negative(weight_decay, 'weight_decay')
+
+        defaults: DEFAULTS = {
+            'lr': lr,
+            'betas': betas,
+            'd': d0,
+            'weight_decay': weight_decay,
+            'weight_decouple': weight_decouple,
+            'fixed_decay': fixed_decay,
+            'step': 0,
+        }
+        super().__init__(params, defaults)
+
+    def __str__(self) -> str:
+        return 'DAdaptLion'
+
+    @torch.no_grad()
+    def reset(self):
+        for group in self.param_groups:
+            group['step'] = 0
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+
+                state = self.state[p]
+
+                state['exp_avg'] = torch.zeros_like(p)
+                state['s'] = torch.zeros_like(p)
+
+    @torch.no_grad()
+    def step(self, closure: CLOSURE = None) -> LOSS:
+        loss: LOSS = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        group = self.param_groups[0]
+        device = group['params'][0].device
+
+        if 'numerator_weighted' not in group:
+            group['numerator_weighted'] = torch.tensor([0.0], device=device)
+        numerator_weighted = group['numerator_weighted']
+
+        sk_l1 = torch.tensor([0.0], device=device)
+        numerator_accumulator = torch.tensor([0.0], device=device)
+
+        beta1, beta2 = group['betas']
+        beta2_sq = math.sqrt(beta2)
+
+        d, lr = group['d'], group['lr']
+        d_lr: float = d * lr
+
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+
+                grad = p.grad
+                if grad.is_sparse:
+                    raise NoSparseGradientError(str(self))
+
+                state = self.state[p]
+                if len(state) == 0:
+                    state['exp_avg'] = torch.zeros_like(p)
+                    state['s'] = torch.zeros_like(p)
+
+                self.apply_weight_decay(
+                    p=p,
+                    grad=grad,
+                    lr=d_lr,
+                    weight_decay=group['weight_decay'],
+                    weight_decouple=group['weight_decouple'],
+                    fixed_decay=group['fixed_decay'],
+                )
+
+                exp_avg, s = state['exp_avg'], state['s']
+
+                update = exp_avg.clone().mul_(beta1).add_(grad, alpha=1.0 - beta1).sign_()
+                p.add_(update, alpha=-d_lr)
+
+                exp_avg.mul_(beta2).add_(grad, alpha=(1.0 - beta2) * d_lr)
+
+                numerator_accumulator.add_(torch.dot(update.flatten(), s.flatten()), alpha=d_lr)
+                s.mul_(beta2_sq).add_(update, alpha=(1.0 - beta2_sq) * d_lr)
+
+                sk_l1.add_(s.abs().sum())
+
+        numerator_weighted.mul_(beta2_sq).add_(numerator_accumulator, alpha=1.0 - beta2_sq)
+
+        if sk_l1 == 0:
+            return loss
+
+        if lr > 0.0:
+            d_hat: float = (numerator_weighted / ((1.0 - beta2_sq) * sk_l1)).item()
+            d = max(d, d_hat)
+
+        for group in self.param_groups:
+            group['step'] += 1
+
+            group['numerator_weighted'] = numerator_weighted
+            group['d'] = d
+
+        return loss
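
A short usage sketch against the implementation above: the effective step is `d * lr`, and the adapted estimate is stored in each param group's `'d'` entry (the toy model and training loop are illustrative):

    import torch
    from torch import nn

    from pytorch_optimizer import DAdaptLion

    model = nn.Linear(10, 2)
    optimizer = DAdaptLion(model.parameters(), lr=1.0, weight_decay=1e-3)

    for _ in range(10):
        loss = model(torch.randn(8, 10)).pow(2).mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    # d starts at d0 (1e-6) and is ratcheted upward via d_hat once sk_l1 is non-zero
    print(optimizer.param_groups[0]['d'])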

pytorch_optimizer/optimizer/fp16.py

Lines changed: 1 addition & 1 deletion
@@ -90,7 +90,7 @@ def decrease_loss_scale(self):
         self.loss_scale = max(self.loss_scale, self.threshold)


-class SafeFP16Optimizer(Optimizer):
+class SafeFP16Optimizer(Optimizer):  # pragma: no cover
    r"""Safe FP16 Optimizer.

    :param optimizer: OPTIMIZER.

tests/constants.py

Lines changed: 8 additions & 5 deletions
@@ -40,6 +40,7 @@
     DAdaptAdaGrad,
     DAdaptAdam,
     DAdaptAdan,
+    DAdaptLion,
     DAdaptSGD,
     DiffGrad,
     Fromage,

@@ -96,6 +97,7 @@
     'scalableshampoo',
     'dadaptadam',
     'dadaptadan',
+    'dadaptlion',
     'adams',
     'adafactor',
     'novograd',

@@ -127,11 +129,11 @@

 OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
     (build_lookahead, {'lr': 5e-1, 'weight_decay': 1e-3}, 5),
-    (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3}, 10),
-    (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'ams_bound': True}, 10),
-    (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'weight_decouple': False}, 10),
-    (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'fixed_decay': True}, 10),
-    (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'rectify': False}, 10),
+    (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3}, 5),
+    (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'ams_bound': True}, 5),
+    (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'weight_decouple': False}, 5),
+    (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'fixed_decay': True}, 5),
+    (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'rectify': True}, 10),
     (AdaBound, {'lr': 1e0, 'gamma': 0.1, 'weight_decay': 1e-3}, 20),
     (AdaBound, {'lr': 1e0, 'gamma': 0.1, 'weight_decay': 1e-3, 'fixed_decay': True}, 20),
     (AdaBound, {'lr': 1e0, 'gamma': 0.1, 'weight_decay': 1e-3, 'weight_decouple': False}, 20),

@@ -329,6 +331,7 @@
     (DAdaptSGD, {'lr': 2e0, 'weight_decay': 1e-3}, 25),
     (DAdaptAdan, {'lr': 2e0, 'weight_decay': 1e-3}, 20),
     (DAdaptAdan, {'lr': 2e0, 'weight_decay': 1e-3, 'weight_decouple': True}, 20),
+    (DAdaptLion, {'lr': 3e0, 'weight_decay': 1e-3}, 20),
     (AdamS, {'lr': 1e0, 'weight_decay': 1e-3}, 10),
     (AdamS, {'lr': 1e0, 'weight_decay': 1e-3, 'ams_bound': True}, 20),
     (AdaFactor, {'lr': 7.5e-1, 'weight_decay': 1e-3, 'scale_parameter': False}, 100),
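
Each tuple is `(optimizer, constructor kwargs, number of steps)`. A hedged sketch of how such a table is typically consumed in a parametrized test; the cases, fixture, and toy objective below are illustrative, not the repository's actual test code:

    import pytest
    import torch
    from torch import nn

    from pytorch_optimizer import AdaBelief, DAdaptLion

    # mirrors the shape of the entries above: (class, kwargs, iterations)
    CASES = [
        (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3}, 5),
        (DAdaptLion, {'lr': 3e0, 'weight_decay': 1e-3}, 20),
    ]


    @pytest.mark.parametrize('optimizer_cls,kwargs,num_iterations', CASES)
    def test_optimizer_step(optimizer_cls, kwargs, num_iterations):
        weight = nn.Parameter(torch.ones(2))  # toy quadratic objective
        optimizer = optimizer_cls([weight], **kwargs)

        for _ in range(num_iterations):
            loss = (weight ** 2).sum()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()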

tests/test_general_optimizer_parameters.py

Lines changed: 1 addition & 0 deletions
@@ -27,6 +27,7 @@ def test_epsilon(optimizer_name):
     'shampoo',
     'scalableshampoo',
     'dadaptsgd',
+    'dadaptlion',
     'adafactor',
     'lion',
     'a2grad',

0 commit comments
