Commit 9110bbd

Merge pull request #157 from kozistr/feature/adashift-optimizer
[Feature] Implement AdaShift optimizer
2 parents 7df3749 + 6f7451e commit 9110bbd

46 files changed: +223, -95 lines

README.rst

Lines changed: 4 additions & 2 deletions
@@ -16,7 +16,7 @@ pytorch-optimizer

 | **pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch.
 | I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas.
-| Currently, 48 optimizers, 6 lr schedulers are supported!
+| Currently, 49 optimizers, 6 lr schedulers are supported!
 |
 | Highly inspired by `pytorch-optimizer <https://github.com/jettify/pytorch-optimizer>`__.

@@ -207,7 +207,9 @@ You can check the supported optimizers & lr schedulers.
 +--------------+---------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+
 | Chebyshev LR | *Acceleration via Fractal Learning Rate Schedules* | | `https://arxiv.org/abs/2103.01338 <https://arxiv.org/abs/2103.01338>`__ | `cite <https://ui.adsabs.harvard.edu/abs/2021arXiv210301338A/exportcitation>`__ |
 +--------------+---------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+
-| Untuned WU | *On the adequacy of untuned warmup for adaptive optimization* | | `https://arxiv.org/abs/1910.04209 <https://arxiv.org/abs/1910.04209>`__ | `cite <https://ui.adsabs.harvard.edu/abs/2019arXiv191004209M/exportcitation>`__ |
+| Un-tuned WU | *On the adequacy of untuned warmup for adaptive optimization* | | `https://arxiv.org/abs/1910.04209 <https://arxiv.org/abs/1910.04209>`__ | `cite <https://ui.adsabs.harvard.edu/abs/2019arXiv191004209M/exportcitation>`__ |
++--------------+---------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+
+| AdaShift | *Decorrelation and Convergence of Adaptive Learning Rate Methods* | `github <https://github.com/MichaelKonobeev/adashift>`__ | `https://arxiv.org/abs/1810.00143v4 <https://arxiv.org/abs/1810.00143v4>`__ | `cite <https://ui.adsabs.harvard.edu/abs/2018arXiv181000143Z/exportcitation>`__ |
 +--------------+---------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+

 Useful Resources
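The AdaShift row added above points at the paper's central trick: update the second-moment estimate with a gradient from several steps back, so it is decorrelated from the current gradient. A minimal illustrative sketch in PyTorch follows; the function name, the `keep_num` window, the EMA first moment, and the mean-of-squares reducer are assumptions for brevity, not the implementation merged in this commit.

    import torch
    from collections import deque

    def adashift_step(param: torch.Tensor, grad: torch.Tensor, state: dict,
                      lr: float = 1e-3, beta1: float = 0.9, beta2: float = 0.999,
                      keep_num: int = 10, eps: float = 1e-10) -> None:
        # Keep the last `keep_num` gradients; buf[0] is the oldest one.
        buf = state.setdefault('grad_queue', deque(maxlen=keep_num))
        buf.append(grad.detach().clone())
        if len(buf) < keep_num:
            return  # warm-up: no sufficiently old gradient yet

        shifted = buf[0]  # gradient from keep_num - 1 steps ago
        exp_avg = state.setdefault('exp_avg', torch.zeros_like(param))
        exp_sq = state.setdefault('exp_sq', torch.zeros(1))

        # First moment: Adam-style EMA over recent gradients (the paper
        # instead weights only the gradients inside the window).
        exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
        # Second moment: fed by the *shifted* gradient, spatially reduced
        # (mean of squares here; the paper also discusses max).
        exp_sq.mul_(beta2).add_(shifted.pow(2).mean(), alpha=1.0 - beta2)

        param.data.add_(exp_avg / (exp_sq.sqrt() + eps), alpha=-lr)

Because the second moment no longer depends on the current gradient, the adaptive step size cannot systematically shrink exactly when a rare, informative gradient arrives, which is the failure mode the paper identifies in Adam.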

docs/optimizer_api.rst

Lines changed: 8 additions & 0 deletions
@@ -440,3 +440,11 @@ AvaGrad

 .. autoclass:: pytorch_optimizer.AvaGrad
     :members:
+
+.. _AdaShift:
+
+AdaShift
+--------
+
+.. autoclass:: pytorch_optimizer.AdaShift
+    :members:

pytorch_optimizer/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -28,6 +28,7 @@
 from pytorch_optimizer.optimizer.adan import Adan
 from pytorch_optimizer.optimizer.adanorm import AdaNorm
 from pytorch_optimizer.optimizer.adapnm import AdaPNM
+from pytorch_optimizer.optimizer.adashift import AdaShift
 from pytorch_optimizer.optimizer.adasmooth import AdaSmooth
 from pytorch_optimizer.optimizer.agc import agc
 from pytorch_optimizer.optimizer.aggmo import AggMo
@@ -141,6 +142,7 @@
     AdaSmooth,
     SRMM,
     AvaGrad,
+    AdaShift,
 ]
 OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
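Since `OPTIMIZERS` keys each class by its lowercased name, the two added lines make AdaShift loadable by string as well as by import. A small usage sketch, assuming the library's `load_optimizer` helper (its exact signature may vary between versions):

    import torch
    from pytorch_optimizer import load_optimizer

    model = torch.nn.Linear(10, 1)

    # 'adashift' resolves through OPTIMIZERS['adashift'] to the AdaShift class.
    opt_class = load_optimizer('adashift')
    optimizer = opt_class(model.parameters(), lr=1e-3)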

pytorch_optimizer/base/optimizer.py

Lines changed: 1 addition & 1 deletion
@@ -127,7 +127,7 @@ def validate_range(x: float, name: str, low: float, high: float, range_type: str
             raise ValueError(f'[-] {name} must be in the range ({low}, {high})')

     @staticmethod
-    def validate_negative(x: float, name: str):
+    def validate_non_negative(x: float, name: str):
         if x < 0.0:
             raise ValueError(f'[-] {name} must be non-negative')
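The new name matches the check's semantics: zero passes and only values below zero raise, so the method validates non-negativity rather than negativity. A quick behavioral sketch, assuming the method lives on `BaseOptimizer` as a static method (the class name is not shown in this hunk):

    from pytorch_optimizer.base.optimizer import BaseOptimizer

    BaseOptimizer.validate_non_negative(0.0, 'weight_decay')   # passes: zero is allowed
    BaseOptimizer.validate_non_negative(1e-3, 'weight_decay')  # passes

    try:
        BaseOptimizer.validate_non_negative(-0.1, 'weight_decay')
    except ValueError as err:
        print(err)  # [-] weight_decay must be non-negative

The renamed call sites in the files below (a2grad, adabelief, adabound, adafactor, adai, adamax) are the mechanical counterpart of this rename.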

pytorch_optimizer/optimizer/a2grad.py

Lines changed: 2 additions & 2 deletions
@@ -30,8 +30,8 @@ def __init__(
         variant: str = 'uni',
     ):
         self.validate_learning_rate(lr)
-        self.validate_negative(lips, 'lips')
-        self.validate_negative(rho, 'rho')
+        self.validate_non_negative(lips, 'lips')
+        self.validate_non_negative(rho, 'rho')
         self.validate_a2grad_variant(variant)

         self.variant = variant

pytorch_optimizer/optimizer/adabelief.py

Lines changed: 2 additions & 2 deletions
@@ -46,8 +46,8 @@ def __init__(
     ):
         self.validate_learning_rate(lr)
         self.validate_betas(betas)
-        self.validate_negative(weight_decay, 'weight_decay')
-        self.validate_negative(eps, 'eps')
+        self.validate_non_negative(weight_decay, 'weight_decay')
+        self.validate_non_negative(eps, 'eps')

         self.n_sma_threshold = n_sma_threshold
         self.degenerated_to_sgd = degenerated_to_sgd

pytorch_optimizer/optimizer/adabound.py

Lines changed: 2 additions & 2 deletions
@@ -41,8 +41,8 @@ def __init__(
     ):
         self.validate_learning_rate(lr)
         self.validate_betas(betas)
-        self.validate_negative(weight_decay, 'weight_decay')
-        self.validate_negative(eps, 'eps')
+        self.validate_non_negative(weight_decay, 'weight_decay')
+        self.validate_non_negative(eps, 'eps')

         defaults: DEFAULTS = {
             'lr': lr,

pytorch_optimizer/optimizer/adafactor.py

Lines changed: 3 additions & 3 deletions
@@ -48,9 +48,9 @@ def __init__(
     ):
         self.validate_learning_rate(lr)
         self.validate_betas(betas)
-        self.validate_negative(weight_decay, 'weight_decay')
-        self.validate_negative(eps1, 'eps1')
-        self.validate_negative(eps2, 'eps2')
+        self.validate_non_negative(weight_decay, 'weight_decay')
+        self.validate_non_negative(eps1, 'eps1')
+        self.validate_non_negative(eps2, 'eps2')

         self.decay_rate = decay_rate
         self.clip_threshold = clip_threshold

pytorch_optimizer/optimizer/adai.py

Lines changed: 2 additions & 2 deletions
@@ -39,8 +39,8 @@ def __init__(
     ):
         self.validate_learning_rate(lr)
         self.validate_betas(betas)
-        self.validate_negative(weight_decay, 'weight_decay')
-        self.validate_negative(eps, 'eps')
+        self.validate_non_negative(weight_decay, 'weight_decay')
+        self.validate_non_negative(eps, 'eps')

         self.use_gc = use_gc

pytorch_optimizer/optimizer/adamax.py

Lines changed: 2 additions & 2 deletions
@@ -36,8 +36,8 @@ def __init__(
     ):
         self.validate_learning_rate(lr)
         self.validate_betas(betas)
-        self.validate_negative(weight_decay, 'weight_decay')
-        self.validate_negative(eps, 'eps')
+        self.validate_non_negative(weight_decay, 'weight_decay')
+        self.validate_non_negative(eps, 'eps')

         defaults: DEFAULTS = {
             'lr': lr,
