
Commit 5113c54

Merge pull request #25 from kozistr/feature/diffgrad-optimizer
[Feature] Implement DiffGrad optimizer
2 parents 4735dce + cc9d942

19 files changed (+368, -179 lines)


.pylintrc

Lines changed: 3 additions & 1 deletion
@@ -57,7 +57,9 @@ disable=
     fixme,
     import-outside-toplevel,
     consider-using-enumerate,
-
+    duplicate-code,
+    too-many-branches,
+    too-many-statements,
 
 # Enable the message, report, category or checker with the given id(s). You can
 # either give multiple identifier separated by comma (,) or put this option

README.rst

Lines changed: 67 additions & 24 deletions
@@ -4,6 +4,7 @@ pytorch-optimizer
 | |workflow| |Documentation Status| |PyPI version| |PyPi download| |black|
 
 | Bunch of optimizer implementations in PyTorch with clean-code, strict types. Also, including useful optimization ideas.
+| Most of the implementations are based on the original paper, but I added some tweaks.
 | Highly inspired by `pytorch-optimizer <https://github.com/jettify/pytorch-optimizer>`__.
 
 Documentation
@@ -53,6 +54,8 @@ Supported Optimizers
 +--------------+----------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
 | AdamP | *Slowing Down the Slowdown for Momentum Optimizers on Scale-invariant Weights* | `github <https://github.com/clovaai/AdamP>`__ | `https://arxiv.org/abs/2006.08217 <https://arxiv.org/abs/2006.08217>`__ |
 +--------------+----------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
+| diffGrad | *An Optimization Method for Convolutional Neural Networks* | `github <https://github.com/shivram1987/diffGrad>`__ | `https://arxiv.org/abs/1909.11015v3 <https://arxiv.org/abs/1909.11015v3>`__ |
++--------------+----------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
 | MADGRAD | *A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic* | `github <https://github.com/facebookresearch/madgrad>`__ | `https://arxiv.org/abs/2101.11075 <https://arxiv.org/abs/2101.11075>`__ |
 +--------------+----------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
 | RAdam | *On the Variance of the Adaptive Learning Rate and Beyond* | `github <https://github.com/LiyuanLucasLiu/RAdam>`__ | `https://arxiv.org/abs/1908.03265 <https://arxiv.org/abs/1908.03265>`__ |
@@ -70,42 +73,51 @@ of the ideas are applied in ``Ranger21`` optimizer.
 
 Also, most of the captures are taken from ``Ranger21`` paper.
 
-Adaptive Gradient Clipping (AGC)
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
++------------------------------------------+-------------------------------------+--------------------------------------------+
+| `Adaptive Gradient Clipping`_ | `Gradient Centralization`_ | `Softplus Transformation`_ |
++------------------------------------------+-------------------------------------+--------------------------------------------+
+| `Gradient Normalization`_ | `Norm Loss`_ | `Positive-Negative Momentum`_ |
++------------------------------------------+-------------------------------------+--------------------------------------------+
+| `Linear learning rate warmup`_ | `Stable weight decay`_ | `Explore-exploit learning rate schedule`_ |
++------------------------------------------+-------------------------------------+--------------------------------------------+
+| `Lookahead`_ | `Chebyshev learning rate schedule`_ | `(Adaptive) Sharpness-Aware Minimization`_ |
++------------------------------------------+-------------------------------------+--------------------------------------------+
+| `On the Convergence of Adam and Beyond`_ | | |
++------------------------------------------+-------------------------------------+--------------------------------------------+
+
+Adaptive Gradient Clipping
+--------------------------
 
 | This idea originally proposed in ``NFNet (Normalized-Free Network)`` paper.
-| AGC (Adaptive Gradient Clipping) clips gradients based on the ``unit-wise ratio of gradient norms to parameter norms``.
+| ``AGC (Adaptive Gradient Clipping)`` clips gradients based on the ``unit-wise ratio of gradient norms to parameter norms``.
 
-- code :
-  `github <https://github.com/deepmind/deepmind-research/tree/master/nfnets>`__
+- code : `github <https://github.com/deepmind/deepmind-research/tree/master/nfnets>`__
 - paper : `arXiv <https://arxiv.org/abs/2102.06171>`__
 
-Gradient Centralization (GC)
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Gradient Centralization
+-----------------------
 
 +-----------------------------------------------------------------------------------------------------------------+
 | .. image:: https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/assets/gradient_centralization.png |
 +-----------------------------------------------------------------------------------------------------------------+
 
-Gradient Centralization (GC) operates directly on gradients by
-centralizing the gradient to have zero mean.
+``Gradient Centralization (GC)`` operates directly on gradients by centralizing the gradient to have zero mean.
 
-- code :
-  `github <https://github.com/Yonghongwei/Gradient-Centralization>`__
+- code : `github <https://github.com/Yonghongwei/Gradient-Centralization>`__
 - paper : `arXiv <https://arxiv.org/abs/2004.01461>`__
 
 Softplus Transformation
-~~~~~~~~~~~~~~~~~~~~~~~
+-----------------------
 
 By running the final variance denom through the softplus function, it lifts extremely tiny values to keep them viable.
 
 - paper : `arXiv <https://arxiv.org/abs/1908.00700>`__
 
 Gradient Normalization
-~~~~~~~~~~~~~~~~~~~~~~
+----------------------
 
 Norm Loss
-~~~~~~~~~
+---------
 
 +---------------------------------------------------------------------------------------------------+
 | .. image:: https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/assets/norm_loss.png |
@@ -114,7 +126,7 @@ Norm Loss
 - paper : `arXiv <https://arxiv.org/abs/2103.06583>`__
 
 Positive-Negative Momentum
-~~~~~~~~~~~~~~~~~~~~~~~~~~
+--------------------------
 
 +--------------------------------------------------------------------------------------------------------------------+
 | .. image:: https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/assets/positive_negative_momentum.png |
@@ -123,8 +135,8 @@ Positive-Negative Momentum
 - code : `github <https://github.com/zeke-xie/Positive-Negative-Momentum>`__
 - paper : `arXiv <https://arxiv.org/abs/2103.17182>`__
 
-Linear learning-rate warm-up
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Linear learning rate warmup
+---------------------------
 
 +----------------------------------------------------------------------------------------------------------+
 | .. image:: https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/assets/linear_lr_warmup.png |
@@ -133,7 +145,7 @@ Linear learning-rate warm-up
 - paper : `arXiv <https://arxiv.org/abs/1910.04209>`__
 
 Stable weight decay
-~~~~~~~~~~~~~~~~~~~
+-------------------
 
 +-------------------------------------------------------------------------------------------------------------+
 | .. image:: https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/assets/stable_weight_decay.png |
@@ -142,8 +154,8 @@ Stable weight decay
 - code : `github <https://github.com/zeke-xie/stable-weight-decay-regularization>`__
 - paper : `arXiv <https://arxiv.org/abs/2011.11152>`__
 
-Explore-exploit learning-rate schedule
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Explore-exploit learning rate schedule
+--------------------------------------
 
 +---------------------------------------------------------------------------------------------------------------------+
 | .. image:: https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/assets/explore_exploit_lr_schedule.png |
@@ -153,7 +165,7 @@ Explore-exploit learning-rate schedule
 - paper : `arXiv <https://arxiv.org/abs/2003.03977>`__
 
 Lookahead
-~~~~~~~~~
+---------
 
 | ``k`` steps forward, 1 step back. ``Lookahead`` consisting of keeping an exponential moving average of the weights that is
 | updated and substituted to the current weights every ``k_{lookahead}`` steps (5 by default).
@@ -162,14 +174,14 @@ Lookahead
 - paper : `arXiv <https://arxiv.org/abs/1907.08610v2>`__
 
 Chebyshev learning rate schedule
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+--------------------------------
 
 Acceleration via Fractal Learning Rate Schedules
 
 - paper : `arXiv <https://arxiv.org/abs/2103.01338v1>`__
 
-(Adaptive) Sharpness-Aware Minimization (A/SAM)
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+(Adaptive) Sharpness-Aware Minimization
+---------------------------------------
 
 | Sharpness-Aware Minimization (SAM) simultaneously minimizes loss value and loss sharpness.
 | In particular, it seeks parameters that lie in neighborhoods having uniformly low loss.
@@ -178,6 +190,11 @@ Acceleration via Fractal Learning Rate Schedules
 - ASAM paper : `paper <https://arxiv.org/abs/2102.11600>`__
 - A/SAM code : `github <https://github.com/davda54/sam>`__
 
+On the Convergence of Adam and Beyond
+-------------------------------------
+
+- paper : `paper <https://openreview.net/forum?id=ryQu7f-RZ>`__
+
 Citations
 ---------
 
@@ -387,6 +404,32 @@ Adaptive Sharpness-Aware Minimization
       year={2021}
     }
 
+diffGrad
+
+::
+
+    @article{dubey2019diffgrad,
+      title={diffgrad: An optimization method for convolutional neural networks},
+      author={Dubey, Shiv Ram and Chakraborty, Soumendu and Roy, Swalpa Kumar and Mukherjee, Snehasis and Singh, Satish Kumar and Chaudhuri, Bidyut Baran},
+      journal={IEEE transactions on neural networks and learning systems},
+      volume={31},
+      number={11},
+      pages={4500--4511},
+      year={2019},
+      publisher={IEEE}
+    }
+
+On the Convergence of Adam and Beyond
+
+::
+
+    @article{reddi2019convergence,
+      title={On the convergence of adam and beyond},
+      author={Reddi, Sashank J and Kale, Satyen and Kumar, Sanjiv},
+      journal={arXiv preprint arXiv:1904.09237},
+      year={2019}
+    }
+
 Author
 ------

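The AGC entry in the README diff above only states the clipping rule, so here is a minimal, assumption-level sketch of it in PyTorch. It clips per tensor rather than unit-wise (per output row) as the paper and the repository's ``agc`` helper do, and the names and default values are illustrative, not the library's::

    import torch

    def adaptive_gradient_clipping(parameters, clip_factor: float = 0.01, eps: float = 1e-3):
        # Rescale each gradient so that ||grad|| / ||param|| never exceeds clip_factor.
        for p in parameters:
            if p.grad is None:
                continue
            param_norm = p.detach().norm().clamp(min=eps)   # floor tiny parameter norms
            grad_norm = p.grad.detach().norm()
            max_norm = param_norm * clip_factor
            if grad_norm > max_norm:
                p.grad.detach().mul_(max_norm / (grad_norm + 1e-6))

Called between ``loss.backward()`` and ``optimizer.step()``, this keeps gradients proportional to the size of the weights they update, which is the behaviour the NFNet paper relies on.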
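Gradient Centralization, described above as centralizing the gradient to zero mean, fits in a few lines. This is a hedged restatement of the idea, not the exact ``centralize_gradient`` function shipped by the repository::

    import torch

    def centralize_gradient_sketch(grad: torch.Tensor) -> torch.Tensor:
        # Subtract the mean over every dimension except the first, so each
        # output unit's (row's / filter's) gradient has zero mean.
        if grad.dim() > 1:
            return grad - grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True)
        return grad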
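The Softplus Transformation section refers to running the usual ``sqrt(v_hat) + eps`` style denominator through softplus so extremely tiny values stay usable. A tiny sketch, where the ``beta`` value is an assumption rather than the repository's setting::

    import torch
    import torch.nn.functional as F

    def softplus_denominator(exp_avg_sq_hat: torch.Tensor, beta: float = 50.0) -> torch.Tensor:
        # softplus with a large beta leaves ordinary denominators almost unchanged
        # but lifts values that are extremely close to zero.
        return F.softplus(exp_avg_sq_hat.sqrt(), beta=beta)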
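The Lookahead paragraph above ("``k`` steps forward, 1 step back") can be summarized as a short loop. The repository ships a ``Lookahead`` wrapper class; the sketch below is only an illustrative restatement with assumed names::

    import torch

    def lookahead_step_sketch(model, base_optimizer, slow_weights, step_idx, k: int = 5, alpha: float = 0.5):
        # slow_weights should be initialized once as [p.detach().clone() for p in model.parameters()].
        base_optimizer.step()                    # one "fast" step
        if (step_idx + 1) % k == 0:              # every k steps, take the "1 step back"
            for slow, fast in zip(slow_weights, model.parameters()):
                slow.add_(fast.detach() - slow, alpha=alpha)   # slow <- slow + alpha * (fast - slow)
                fast.data.copy_(slow)                          # substitute the averaged weights into the model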
pytorch_optimizer/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -1,9 +1,11 @@
+# pylint: disable=unused-import
 from pytorch_optimizer.adabelief import AdaBelief
 from pytorch_optimizer.adabound import AdaBound
 from pytorch_optimizer.adahessian import AdaHessian
 from pytorch_optimizer.adamp import AdamP
 from pytorch_optimizer.agc import agc
 from pytorch_optimizer.chebyshev_schedule import get_chebyshev_schedule
+from pytorch_optimizer.diffgrad import DiffGrad
 from pytorch_optimizer.gc import centralize_gradient
 from pytorch_optimizer.lookahead import Lookahead
 from pytorch_optimizer.madgrad import MADGRAD
@@ -13,4 +15,4 @@
 from pytorch_optimizer.sam import SAM
 from pytorch_optimizer.sgdp import SGDP
 
-__VERSION__ = '0.0.7'
+__VERSION__ = '0.0.8'
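The new ``pytorch_optimizer/diffgrad.py`` module added by this PR is not shown in this excerpt. As a rough sketch of the diffGrad rule from the cited paper (arXiv:1909.11015), and explicitly not the repository's implementation: it is Adam with the first moment damped element-wise by a friction coefficient ``xi = sigmoid(|g_{t-1} - g_t|)``, so steps shrink where consecutive gradients barely change. The class name and defaults below are assumptions::

    import torch
    from torch.optim import Optimizer

    class DiffGradSketch(Optimizer):
        def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0):
            super().__init__(params, dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay))

        @torch.no_grad()
        def step(self, closure=None):
            loss = None
            if closure is not None:
                with torch.enable_grad():
                    loss = closure()
            for group in self.param_groups:
                beta1, beta2 = group['betas']
                for p in group['params']:
                    if p.grad is None:
                        continue
                    grad = p.grad
                    if group['weight_decay'] != 0.0:
                        grad = grad.add(p, alpha=group['weight_decay'])  # plain L2 penalty

                    state = self.state[p]
                    if len(state) == 0:
                        state['step'] = 0
                        state['exp_avg'] = torch.zeros_like(p)
                        state['exp_avg_sq'] = torch.zeros_like(p)
                        state['previous_grad'] = torch.zeros_like(p)

                    state['step'] += 1
                    exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                    exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                    bias_correction1 = 1.0 - beta1 ** state['step']
                    bias_correction2 = 1.0 - beta2 ** state['step']
                    denom = (exp_avg_sq / bias_correction2).sqrt().add_(group['eps'])

                    # friction coefficient: ~1.0 when the gradient changes a lot (Adam-like step),
                    # ~0.5 when consecutive gradients are nearly equal (damped step near flat regions)
                    xi = torch.sigmoid((state['previous_grad'] - grad).abs())
                    state['previous_grad'] = grad.clone()

                    p.addcdiv_(xi * exp_avg, denom, value=-group['lr'] / bias_correction1)
            return loss

After this commit the real optimizer is importable as ``from pytorch_optimizer import DiffGrad`` (see the ``__init__.py`` hunk above); a hypothetical call would look like ``DiffGrad(model.parameters(), lr=1e-3)``.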

pytorch_optimizer/adabelief.py

Lines changed: 10 additions & 5 deletions
@@ -44,17 +44,22 @@ def __init__(
         degenerated_to_sgd: bool = True,
     ):
         """AdaBelief optimizer
-        :param params: PARAMS. iterable of parameters to optimize or dicts defining parameter groups
+        :param params: PARAMS. iterable of parameters to optimize
+            or dicts defining parameter groups
         :param lr: float. learning rate
-        :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
-        :param eps: float. term added to the denominator to improve numerical stability
+        :param betas: BETAS. coefficients used for computing running averages
+            of gradient and the squared hessian trace
+        :param eps: float. term added to the denominator
+            to improve numerical stability
         :param weight_decay: float. weight decay (L2 penalty)
         :param n_sma_threshold: (recommended is 5)
         :param amsgrad: bool. whether to use the AMSBound variant
-        :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW
+        :param weight_decouple: bool. the optimizer uses decoupled weight decay
+            as in AdamW
         :param fixed_decay: bool.
         :param rectify: bool. perform the rectified update similar to RAdam
-        :param degenerated_to_sgd: bool. perform SGD update when variance of gradient is high
+        :param degenerated_to_sgd: bool. perform SGD update
+            when variance of gradient is high
         """
         self.lr = lr
         self.betas = betas

pytorch_optimizer/adabound.py

Lines changed: 12 additions & 8 deletions
@@ -15,7 +15,7 @@
 
 class AdaBound(Optimizer):
     """
-    Reference : https://github.com/Luolc/AdaBound/blob/master/adabound/adabound.py
+    Reference : https://github.com/Luolc/AdaBound
     Example :
         from pytorch_optimizer import AdaBound
         ...
@@ -43,14 +43,18 @@ def __init__(
         amsbound: bool = False,
     ):
         """AdaBound optimizer
-        :param params: PARAMS. iterable of parameters to optimize or dicts defining parameter groups
+        :param params: PARAMS. iterable of parameters to optimize
+            or dicts defining parameter groups
         :param lr: float. learning rate
        :param final_lr: float. final learning rate
-        :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
+        :param betas: BETAS. coefficients used for computing running averages
+            of gradient and the squared hessian trace
         :param gamma: float. convergence speed of the bound functions
-        :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW
+        :param weight_decouple: bool. the optimizer uses decoupled weight decay
+            as in AdamW
         :param fixed_decay: bool.
-        :param eps: float. term added to the denominator to improve numerical stability
+        :param eps: float. term added to the denominator
+            to improve numerical stability
         :param weight_decay: float. weight decay (L2 penalty)
         :param amsbound: bool. whether to use the AMSBound variant
         """
@@ -75,11 +79,11 @@ def __init__(
         self.base_lrs = [group['lr'] for group in self.param_groups]
 
     def check_valid_parameters(self):
-        if 0.0 > self.lr:
+        if self.lr < 0.0:
             raise ValueError(f'Invalid learning rate : {self.lr}')
-        if 0.0 > self.eps:
+        if self.eps < 0.0:
             raise ValueError(f'Invalid eps : {self.eps}')
-        if 0.0 > self.weight_decay:
+        if self.weight_decay < 0.0:
             raise ValueError(f'Invalid weight_decay : {self.weight_decay}')
         if not 0.0 <= self.betas[0] < 1.0:
             raise ValueError(f'Invalid beta_0 : {self.betas[0]}')

pytorch_optimizer/adahessian.py

Lines changed: 24 additions & 14 deletions
@@ -1,3 +1,5 @@
+from typing import Dict, Iterable
+
 import torch
 from torch.optim import Optimizer
 
@@ -12,7 +14,7 @@
 
 class AdaHessian(Optimizer):
     """
-    Reference : https://github.com/davda54/ada-hessian/blob/master/ada_hessian.py
+    Reference : https://github.com/davda54/ada-hessian
     Example :
         from pytorch_optimizer import AdaHessian
         ...
@@ -40,15 +42,21 @@ def __init__(
         seed: int = 2147483647,
     ):
         """
-        :param params: PARAMS. iterable of parameters to optimize or dicts defining parameter groups
+        :param params: PARAMS. iterable of parameters to optimize
+            or dicts defining parameter groups
         :param lr: float. learning rate.
-        :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
-        :param eps: float. term added to the denominator to improve numerical stability
+        :param betas: BETAS. coefficients used for computing running averages
+            of gradient and the squared hessian trace
+        :param eps: float. term added to the denominator
+            to improve numerical stability
         :param weight_decay: float. weight decay (L2 penalty)
         :param hessian_power: float. exponent of the hessian trace
-        :param update_each: int. compute the hessian trace approximation only after *this* number of steps
-        :param n_samples: int. how many times to sample `z` for the approximation of the hessian trace
-        :param average_conv_kernel: bool. average out the hessian traces of convolutional kernels as in the paper.
+        :param update_each: int. compute the hessian trace approximation
+            only after *this* number of steps
+        :param n_samples: int. how many times to sample `z`
+            for the approximation of the hessian trace
+        :param average_conv_kernel: bool. average out the hessian traces
+            of convolutional kernels as in the paper.
         :param seed: int.
         """
         self.lr = lr
@@ -63,8 +71,8 @@ def __init__(
 
         self.check_valid_parameters()
 
-        # use a separate generator that deterministically generates the same `z`s across all GPUs
-        # in case of distributed training
+        # use a separate generator that deterministically generates
+        # the same `z`s across all GPUs in case of distributed training
         self.generator: torch.Generator = torch.Generator().manual_seed(
             self.seed
         )
@@ -83,11 +91,11 @@ def __init__(
             self.state[p]['hessian_step'] = 0
 
     def check_valid_parameters(self):
-        if 0.0 > self.lr:
+        if self.lr < 0.0:
             raise ValueError(f'Invalid learning rate : {self.lr}')
-        if 0.0 > self.eps:
+        if self.eps < 0.0:
             raise ValueError(f'Invalid eps : {self.eps}')
-        if 0.0 > self.weight_decay:
+        if self.weight_decay < 0.0:
             raise ValueError(f'Invalid weight_decay : {self.weight_decay}')
         if not 0.0 <= self.betas[0] < 1.0:
             raise ValueError(f'Invalid beta_0 : {self.betas[0]}')
@@ -96,7 +104,7 @@ def check_valid_parameters(self):
         if not 0.0 <= self.hessian_power < 1.0:
             raise ValueError(f'Invalid hessian_power : {self.hessian_power}')
 
-    def get_params(self):
+    def get_params(self) -> Iterable[Dict]:
         """Gets all parameters in all param_groups with gradients"""
         return (
             p
@@ -116,7 +124,9 @@ def zero_hessian(self):
 
     @torch.no_grad()
    def set_hessian(self):
-        """Computes the Hutchinson approximation of the hessian trace and accumulates it for each trainable parameter"""
+        """Computes the Hutchinson approximation of the hessian trace
+        and accumulates it for each trainable parameter
+        """
         params = []
         for p in filter(
             lambda param: param.grad is not None, self.get_params()
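The ``set_hessian`` docstring above refers to the Hutchinson approximation. As a small illustrative sketch of that estimator, with assumed helper names and not the module's actual code: for Rademacher vectors ``z``, the element-wise product ``z * (H z)`` is an unbiased estimate of the Hessian diagonal, and its sum estimates the trace. The Hessian-vector product is obtained with one extra backward pass, so ``grads`` must come from ``torch.autograd.grad(loss, params, create_graph=True)``::

    import torch

    def hutchinson_hessian_estimate(params, grads, n_samples: int = 1):
        estimates = [torch.zeros_like(p) for p in params]
        for _ in range(n_samples):
            # Rademacher noise in {-1, +1}, same shape/dtype as each parameter
            zs = [torch.randint_like(p, 2) * 2.0 - 1.0 for p in params]
            # Hessian-vector products: differentiate (grad . z) w.r.t. the parameters
            h_zs = torch.autograd.grad(grads, params, grad_outputs=zs, retain_graph=True)
            for est, h_z, z in zip(estimates, h_zs, zs):
                est.add_(h_z * z, alpha=1.0 / n_samples)
        return estimates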
