
Commit 4dbfc23

Merge pull request #160 from kozistr/feature/adadelta-optimizer
[Feature] Implement AdaDelta optimizer
2 parents 2e97e5f + 4cc9bd7 commit 4dbfc23

10 files changed: +131 -21 lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 7 deletions

@@ -23,14 +23,9 @@ jobs:
         uses: actions/setup-python@v4
         with:
           python-version: ${{ matrix.python-version }}
-      - name: Cache pip
-        uses: actions/cache@v3
-        with:
-          path: ~/.cache/pip
-          key: ${{ runner.os }}-pip-${{ hashFiles('requirements-dev.txt') }}
-          restore-keys: ${{ runner.os }}-pip-
+          cache: 'pip'
       - name: Install dependencies
-        run: pip install -r requirements-dev.txt
+        run: pip --disable-pip-version-check install --no-compile -r requirements-dev.txt
       - name: Check lint
         run: make check
       - name: Check test

.github/workflows/publish.yml

Lines changed: 3 additions & 8 deletions

@@ -32,16 +32,11 @@ jobs:
         uses: actions/setup-python@v4
         with:
           python-version: 3.11
-      - name: Cache pip
-        uses: actions/cache@v3
-        with:
-          path: ~/.cache/pip
-          key: ${{ runner.os }}-pip-${{ hashFiles('requirements-dev.txt') }}
-          restore-keys: ${{ runner.os }}-pip-
+          cache: 'pip'
       - name: Install dependencies
         run: |
-          python3 -m pip install poetry
-          python3 -m pip install -r requirements.txt
+          pip install poetry
+          pip install -r requirements.txt
       - name: Publish package to PyPI
         env:
           PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }}

README.rst

Lines changed: 3 additions & 1 deletion

@@ -16,7 +16,7 @@ pytorch-optimizer

 | **pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch.
 | I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas.
-| Currently, 49 optimizers, 6 lr schedulers are supported!
+| Currently, 50 optimizers, 6 lr schedulers are supported!
 |
 | Highly inspired by `pytorch-optimizer <https://github.com/jettify/pytorch-optimizer>`__.

@@ -211,6 +211,8 @@ You can check the supported optimizers & lr schedulers.
 +--------------+---------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+
 | AdaShift | *Decorrelation and Convergence of Adaptive Learning Rate Methods* | `github <https://github.com/MichaelKonobeev/adashift>`__ | `https://arxiv.org/abs/1810.00143v4 <https://arxiv.org/abs/1810.00143v4>`__ | `cite <https://ui.adsabs.harvard.edu/abs/2018arXiv181000143Z/exportcitation>`__ |
 +--------------+---------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+
+| AdaDelta | *An Adaptive Learning Rate Method* | | `https://arxiv.org/abs/1212.5701v1 <https://arxiv.org/abs/1212.5701v1>`__ | `cite <https://ui.adsabs.harvard.edu/abs/2012arXiv1212.5701Z/exportcitation>`__ |
++--------------+---------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+

 Useful Resources
 ----------------

docs/optimizer_api.rst

Lines changed: 8 additions & 0 deletions

@@ -448,3 +448,11 @@ AdaShift

 .. autoclass:: pytorch_optimizer.AdaShift
     :members:
+
+.. _AdaDelta:
+
+AdaDelta
+--------
+
+.. autoclass:: pytorch_optimizer.AdaDelta
+    :members:

pyproject.toml

Lines changed: 2 additions & 2 deletions

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "pytorch_optimizer"
-version = "2.8.0"
+version = "2.9.0"
 description = "optimizer & lr scheduler implementations in PyTorch with clean-code, strict types. Also, including useful optimization ideas."
 license = "Apache-2.0"
 authors = ["kozistr <[email protected]>"]
@@ -9,7 +9,7 @@ readme = "README.rst"
 homepage = "https://github.com/kozistr/pytorch_optimizer"
 repository = "https://github.com/kozistr/pytorch_optimizer"
 documentation = "https://pytorch-optimizers.readthedocs.io/en/latest"
-keywords = ["pytorch", "deep-learning", "optimizer", "lr scheduler", "A2Grad", "ASGD", "AccSGD", "AdaBelief", "AdaBound", "AdaFactor", "AdaMax", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "Adai", "AdamP", "AdamS", "Adan", "AggMo", "AliG", "Apollo", "AvaGrad", "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan", "DAdaptSGD", "DiffGrad", "Fromage", "Gravity", "LARS", "Lamb", "Lion", "MADGRAD", "MSVAG", "Nero", "NovoGrad", "PID", "PNM", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "SGDP", "SGDW", "SM3", "SRMM", "SWATS", "ScalableShampoo", "Shampoo", "Yogi", "SAM", "GSAM", "PCGrad", "RotoGrad"]
+keywords = ["pytorch", "deep-learning", "optimizer", "lr scheduler", "A2Grad", "ASGD", "AccSGD", "AdaBelief", "AdaBound", "AdaDelta", "AdaFactor", "AdaMax", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "Adai", "AdamP", "AdamS", "Adan", "AggMo", "AliG", "Apollo", "AvaGrad", "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan", "DAdaptSGD", "DiffGrad", "Fromage", "Gravity", "LARS", "Lamb", "Lion", "MADGRAD", "MSVAG", "Nero", "NovoGrad", "PID", "PNM", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "SGDP", "SGDW", "SM3", "SRMM", "SWATS", "ScalableShampoo", "Shampoo", "Yogi", "SAM", "GSAM", "PCGrad", "RotoGrad"]
 classifiers = [
     "License :: OSI Approved :: Apache Software License",
     "Development Status :: 5 - Production/Stable",

pytorch_optimizer/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -19,6 +19,7 @@
 from pytorch_optimizer.optimizer.a2grad import A2Grad
 from pytorch_optimizer.optimizer.adabelief import AdaBelief
 from pytorch_optimizer.optimizer.adabound import AdaBound
+from pytorch_optimizer.optimizer.adadelta import AdaDelta
 from pytorch_optimizer.optimizer.adafactor import AdaFactor
 from pytorch_optimizer.optimizer.adai import Adai
 from pytorch_optimizer.optimizer.adamax import AdaMax
@@ -143,6 +144,7 @@
     SRMM,
     AvaGrad,
     AdaShift,
+    AdaDelta,
 ]
 OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
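With the import and the OPTIMIZER_LIST entry above, AdaDelta is exported from the package root and picked up by the lowercase-keyed OPTIMIZERS registry. A minimal sketch of what that enables once this commit is installed (the torch.nn.Linear model below is only an illustration, not part of this commit):

# Sketch only: exercises the export and the lowercase-keyed registry shown in
# the __init__.py diff above; the tiny model is a stand-in, not from the repo.
import torch

from pytorch_optimizer import OPTIMIZERS, AdaDelta

model = torch.nn.Linear(4, 1)

# OPTIMIZERS maps str(cls.__name__).lower() -> class, so 'adadelta' resolves here.
assert OPTIMIZERS['adadelta'] is AdaDelta

optimizer = OPTIMIZERS['adadelta'](model.parameters(), lr=1.0, rho=0.9)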

pytorch_optimizer/optimizer/adadelta.py

Lines changed: 106 additions & 0 deletions

@@ -0,0 +1,106 @@
+import torch
+from torch.optim.optimizer import Optimizer
+
+from pytorch_optimizer.base.exception import NoSparseGradientError
+from pytorch_optimizer.base.optimizer import BaseOptimizer
+from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS
+
+
+class AdaDelta(Optimizer, BaseOptimizer):
+    r"""An Adaptive Learning Rate Method.
+
+    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
+    :param lr: float. learning rate.
+    :param rho: float. coefficient used for computing a running average of squared gradients.
+    :param weight_decay: float. weight decay (L2 penalty).
+    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
+    :param fixed_decay: bool. fix weight decay.
+    :param eps: float. term added to the denominator to improve numerical stability.
+    """
+
+    def __init__(
+        self,
+        params: PARAMETERS,
+        lr: float = 1.0,
+        rho: float = 0.9,
+        weight_decay: float = 0.0,
+        weight_decouple: bool = False,
+        fixed_decay: bool = False,
+        eps: float = 1e-6,
+    ):
+        self.validate_learning_rate(lr)
+        self.validate_range(rho, 'rho', 0.0, 1.0)
+        self.validate_non_negative(weight_decay, 'weight_decay')
+        self.validate_non_negative(eps, 'eps')
+
+        defaults: DEFAULTS = {
+            'lr': lr,
+            'rho': rho,
+            'weight_decay': weight_decay,
+            'weight_decouple': weight_decouple,
+            'fixed_decay': fixed_decay,
+            'eps': eps,
+        }
+        super().__init__(params, defaults)
+
+    def __str__(self) -> str:
+        return 'AdaDelta'
+
+    @torch.no_grad()
+    def reset(self):
+        for group in self.param_groups:
+            group['step'] = 0
+            for p in group['params']:
+                state = self.state[p]
+
+                state['square_avg'] = torch.zeros_like(p)
+                state['acc_delta'] = torch.zeros_like(p)
+
+    @torch.no_grad()
+    def step(self, closure: CLOSURE = None) -> LOSS:
+        loss: LOSS = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            if 'step' in group:
+                group['step'] += 1
+            else:
+                group['step'] = 1
+
+            rho: float = group['rho']
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+
+                grad = p.grad
+                if grad.is_sparse:
+                    raise NoSparseGradientError(str(self))
+
+                state = self.state[p]
+
+                if len(state) == 0:
+                    state['square_avg'] = torch.zeros_like(p)
+                    state['acc_delta'] = torch.zeros_like(p)
+
+                self.apply_weight_decay(
+                    p=p,
+                    grad=grad,
+                    lr=group['lr'],
+                    weight_decay=group['weight_decay'],
+                    weight_decouple=group['weight_decouple'],
+                    fixed_decay=group['fixed_decay'],
+                )
+
+                square_avg, acc_delta = state['square_avg'], state['acc_delta']
+                square_avg.mul_(rho).addcmul_(grad, grad, value=1.0 - rho)
+
+                std = square_avg.add(group['eps']).sqrt_()
+                delta = acc_delta.add(group['eps']).sqrt_().div_(std).mul_(grad)
+
+                acc_delta.mul_(rho).addcmul_(delta, delta, value=1.0 - rho)
+                p.add_(delta, alpha=-group['lr'])
+
+        return loss
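For reference, the per-parameter update computed by the new step() above, writing square_avg as v_t and acc_delta as u_t; the lr factor and the optional decoupled weight decay applied beforehand are the only departures from the plain AdaDelta rule of the paper:

    v_t      = \rho\, v_{t-1} + (1 - \rho)\, g_t^2
    \Delta_t = \frac{\sqrt{u_{t-1} + \epsilon}}{\sqrt{v_t + \epsilon}} \odot g_t
    u_t      = \rho\, u_{t-1} + (1 - \rho)\, \Delta_t^2
    \theta_t = \theta_{t-1} - \mathrm{lr} \cdot \Delta_t

Note that u_t accumulates the unscaled \Delta_t, so lr only rescales the applied step, matching p.add_(delta, alpha=-group['lr']).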

pytorch_optimizer/optimizer/dadapt.py

Lines changed: 1 addition & 1 deletion

@@ -327,7 +327,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:

         # it's not Adam Debias
         d_lr: float = self.apply_adam_debias(
-            group['bias_correction'], step_size=d * lr, bias_correction1=bias_correction
+            not group['bias_correction'], step_size=d * lr, bias_correction1=bias_correction
         )

         sk_l1 = torch.tensor([0.0], device=device)

tests/constants.py

Lines changed: 3 additions & 1 deletion

@@ -17,6 +17,7 @@
     AccSGD,
     AdaBelief,
     AdaBound,
+    AdaDelta,
     AdaFactor,
     Adai,
     AdaMax,
@@ -312,7 +313,7 @@
     (Adan, {'lr': 5e-1, 'weight_decay': 1e-3, 'weight_decouple': True}, 5),
     (DAdaptAdaGrad, {'lr': 3e0, 'weight_decay': 1e-3}, 30),
     (DAdaptAdaGrad, {'lr': 5e0, 'weight_decay': 1e-3, 'momentum': 0.1}, 20),
-    (DAdaptAdam, {'lr': 5e4, 'weight_decay': 1e-1}, 10),
+    (DAdaptAdam, {'lr': 5e4, 'weight_decay': 1e-3}, 5),
     (DAdaptSGD, {'lr': 2e0, 'weight_decay': 1e-3}, 25),
     (DAdaptAdan, {'lr': 2e0, 'weight_decay': 1e-3}, 20),
     (DAdaptAdan, {'lr': 2e0, 'weight_decay': 1e-3, 'weight_decouple': True}, 20),
@@ -363,6 +364,7 @@
     (SRMM, {'lr': 5e-1}, 5),
     (AvaGrad, {'lr': 1e1}, 5),
     (AdaShift, {'lr': 1e0, 'keep_num': 1}, 5),
+    (AdaDelta, {'lr': 5e1}, 5),
 ]
 ADANORM_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
     (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adanorm': True}, 10),
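The new (AdaDelta, {'lr': 5e1}, 5) entry above appears to feed the shared optimizer test configurations. A rough standalone sketch of that kind of run on a toy quadratic objective (the objective and parameter here are illustrative, not the repo's test harness):

# Illustrative only: mirrors the (AdaDelta, {'lr': 5e1}, 5) configuration on a
# toy quadratic; the real test environment and models live in the test suite.
import torch

from pytorch_optimizer import AdaDelta

param = torch.nn.Parameter(torch.tensor([1.5, -2.0]))
optimizer = AdaDelta([param], lr=5e1)

for _ in range(5):  # third tuple element: iteration budget
    def closure() -> torch.Tensor:
        optimizer.zero_grad()
        loss = (param ** 2).sum()
        loss.backward()
        return loss

    optimizer.step(closure)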

tests/test_load_optimizers.py

Lines changed: 1 addition & 1 deletion

@@ -16,4 +16,4 @@ def test_load_optimizers_invalid(invalid_optimizer_names):


 def test_get_supported_optimizers():
-    assert len(get_supported_optimizers()) == 49
+    assert len(get_supported_optimizers()) == 50
