
Commit e7834d5

Merge pull request #177 from kozistr/refactor/sophia-optimizer
[Refactor] AdaHessian, SophiaH optimizers
2 parents 5c18933 + 6008a00 commit e7834d5

16 files changed, +266 -169 lines changed

README.rst

Lines changed: 8 additions & 2 deletions
@@ -16,7 +16,7 @@ pytorch-optimizer
 
 | **pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch.
 | I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas.
-| Currently, 51 optimizers, 6 lr schedulers are supported!
+| Currently, 54 optimizers, 6 lr schedulers are supported!
 |
 | Highly inspired by `pytorch-optimizer <https://github.com/jettify/pytorch-optimizer>`__.
 
@@ -213,7 +213,13 @@ You can check the supported optimizers & lr schedulers.
 +--------------+---------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+
 | AdaDelta | *An Adaptive Learning Rate Method* | | `https://arxiv.org/abs/1212.5701v1 <https://arxiv.org/abs/1212.5701v1>`__ | `cite <https://ui.adsabs.harvard.edu/abs/2012arXiv1212.5701Z/exportcitation>`__ |
 +--------------+---------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+
-| Amos | * An Adam-style Optimizer with Adaptive Weight Decay towards Model-Oriented Scale* | `github <https://github.com/google-research/jestimator>`__ | `https://arxiv.org/abs/2210.11693 <https://arxiv.org/abs/2210.11693>`__ | `cite <https://ui.adsabs.harvard.edu/abs/2022arXiv221011693T/exportcitation>`__ |
+| Amos | *An Adam-style Optimizer with Adaptive Weight Decay towards Model-Oriented Scale* | `github <https://github.com/google-research/jestimator>`__ | `https://arxiv.org/abs/2210.11693 <https://arxiv.org/abs/2210.11693>`__ | `cite <https://ui.adsabs.harvard.edu/abs/2022arXiv221011693T/exportcitation>`__ |
++--------------+---------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+
+| SignSGD | *Compressed Optimisation for Non-Convex Problems* | `github <https://github.com/jxbz/signSGD>`__ | `https://arxiv.org/abs/1802.04434 <https://arxiv.org/abs/1802.04434>`__ | `cite <https://ui.adsabs.harvard.edu/abs/2018arXiv180204434B/exportcitation>`__ |
++--------------+---------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+
+| AdaHessian | *An Adaptive Second Order Optimizer for Machine Learning* | `github <https://github.com/amirgholami/adahessian>`__ | `https://arxiv.org/abs/2006.00719 <https://arxiv.org/abs/2006.00719>`__ | `cite <https://github.com/amirgholami/adahessian#citation>`__ |
++--------------+---------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+
+| Sophia | *A Scalable Stochastic Second-order Optimizer for Language Model Pre-training* | `github <https://github.com/Liuhong99/Sophia>`__ | `https://arxiv.org/abs/2305.14342 <https://arxiv.org/abs/2305.14342>`__ | `cite <https://github.com/Liuhong99/Sophia>`__ |
 +--------------+---------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+
 
 Useful Resources
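The rows added above register SignSGD, AdaHessian and Sophia in the README's optimizer table. As a quick orientation (not part of this commit), the sketch below shows the usual way the Hessian-based optimizers are driven from user code; the constructor arguments are illustrative, and the one commit-relevant detail is that the backward pass must keep the autograd graph alive so the optimizer can take Hessian-vector products at `step()` time.

```python
# Illustrative usage sketch (not from this commit); hyperparameters are made up.
import torch
from pytorch_optimizer import AdaHessian  # SophiaH is driven the same way

model = torch.nn.Linear(16, 1)
optimizer = AdaHessian(model.parameters(), lr=1e-1)

x, y = torch.randn(32, 16), torch.randn(32, 1)
loss = torch.nn.functional.mse_loss(model(x), y)

# create_graph=True keeps the autograd graph alive so the optimizer can
# estimate the Hessian diagonal via Hessian-vector products inside step().
loss.backward(create_graph=True)
optimizer.step()
optimizer.zero_grad(set_to_none=True)
```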

docs/changelogs/v.2.10.0.md

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+## Change Log
+
+### Feature
+
+* Implement Amos optimizer (#174)
+  * [An Adam-style Optimizer with Adaptive Weight Decay towards Model-Oriented Scale](https://arxiv.org/abs/2210.11693)
+* Implement SignSGD optimizer (#176) (thanks to @i404788)
+  * [Compressed Optimisation for Non-Convex Problems](https://arxiv.org/abs/1802.04434)
+* Implement AdaHessian optimizer (#176) (thanks to @i404788)
+  * [An Adaptive Second Order Optimizer for Machine Learning](https://arxiv.org/abs/2006.00719)
+* Implement SophiaH optimizer (#173, #176) (thanks to @i404788)
+  * [A Scalable Stochastic Second-order Optimizer for Language Model Pre-training](https://arxiv.org/abs/2305.14342)
+* Implement re-usable functions to compute hessian in `BaseOptimizer` (#176, #177) (thanks to @i404788)
+  * two types of distribution are supported (`gaussian`, `rademacher`).
+* Support `AdamD` variant for AdaHessian optimizer (#177)
+
+### Diff
+
+[2.9.1...2.10.0](https://github.com/kozistr/pytorch_optimizer/compare/v2.9.1...v2.10.0)
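The changelog's last two feature bullets add reusable Hessian helpers with two probe distributions. As background (not part of the commit), the identity they rely on is that for probe vectors z with E[z zᵀ] = I, the vector z ⊙ (Hz) is an unbiased estimate of diag(H); a self-contained check:

```python
# Self-contained sketch of Hutchinson's diagonal estimator, the technique behind
# `compute_hutchinson_hessian`. Both supported probe distributions are shown.
import torch

torch.manual_seed(0)
a = torch.randn(5, 5)
hess = a @ a.t()  # a symmetric matrix standing in for a Hessian


def estimate_diag(h: torch.Tensor, distribution: str, num_samples: int = 20_000) -> torch.Tensor:
    est = torch.zeros(h.size(0))
    for _ in range(num_samples):
        if distribution == 'rademacher':  # entries in {-1.0, +1.0}
            z = torch.randint(0, 2, (h.size(0),), dtype=h.dtype) * 2.0 - 1.0
        else:  # 'gaussian', N(0, I)
            z = torch.randn(h.size(0), dtype=h.dtype)
        est += z * (h @ z) / num_samples  # E[z * (H @ z)] == diag(H)
    return est


for dist in ('gaussian', 'rademacher'):
    err = (estimate_diag(hess, dist) - hess.diag()).abs().max().item()
    print(f'{dist}: max |error| = {err:.3f}')  # shrinks as num_samples grows
```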

docs/optimizer_api.rst

Lines changed: 24 additions & 0 deletions
@@ -464,3 +464,27 @@ Amos
 
 .. autoclass:: pytorch_optimizer.Amos
     :members:
+
+.. _SignSGD:
+
+SignSGD
+-------
+
+.. autoclass:: pytorch_optimizer.SignSGD
+    :members:
+
+.. _AdaHessian:
+
+AdaHessian
+----------
+
+.. autoclass:: pytorch_optimizer.AdaHessian
+    :members:
+
+.. _SophiaH:
+
+SophiaH
+-------
+
+.. autoclass:: pytorch_optimizer.SophiaH
+    :members:

pyproject.toml

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "pytorch_optimizer"
-version = "2.9.1"
+version = "2.10.0"
 description = "optimizer & lr scheduler collections in PyTorch"
 license = "Apache-2.0"
 authors = ["kozistr <[email protected]>"]
@@ -9,7 +9,7 @@ readme = "README.rst"
 homepage = "https://github.com/kozistr/pytorch_optimizer"
 repository = "https://github.com/kozistr/pytorch_optimizer"
 documentation = "https://pytorch-optimizers.readthedocs.io/en/latest"
-keywords = ["pytorch", "deep-learning", "optimizer", "lr scheduler", "A2Grad", "ASGD", "AccSGD", "AdaBelief", "AdaBound", "AdaDelta", "AdaFactor", "AdaMax", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "Adai", "AdamP", "AdamS", "Adan", "AggMo", "AliG", "Amos", "Apollo", "AvaGrad", "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan", "DAdaptSGD", "DiffGrad", "Fromage", "Gravity", "LARS", "Lamb", "Lion", "MADGRAD", "MSVAG", "Nero", "NovoGrad", "PID", "PNM", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "SGDP", "SGDW", "SM3", "SRMM", "SWATS", "ScalableShampoo", "Shampoo", "Yogi", "SAM", "GSAM", "PCGrad", "RotoGrad"]
+keywords = ["pytorch", "deep-learning", "optimizer", "lr scheduler", "A2Grad", "ASGD", "AccSGD", "AdaBelief", "AdaBound", "AdaDelta", "AdaFactor", "AdaMax", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdaHessian", "Adai", "AdamP", "AdamS", "Adan", "AggMo", "AliG", "Amos", "Apollo", "AvaGrad", "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan", "DAdaptSGD", "DiffGrad", "Fromage", "Gravity", "LARS", "Lamb", "Lion", "MADGRAD", "MSVAG", "Nero", "NovoGrad", "PID", "PNM", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "SGDP", "SGDW", "SignSGD", "SM3", "SopihaH", "SRMM", "SWATS", "ScalableShampoo", "Shampoo", "Yogi", "SAM", "GSAM", "PCGrad", "RotoGrad"]
 classifiers = [
     "License :: OSI Approved :: Apache Software License",
     "Development Status :: 5 - Production/Stable",

pytorch_optimizer/__init__.py

Lines changed: 3 additions & 3 deletions
@@ -21,6 +21,7 @@
 from pytorch_optimizer.optimizer.adabound import AdaBound
 from pytorch_optimizer.optimizer.adadelta import AdaDelta
 from pytorch_optimizer.optimizer.adafactor import AdaFactor
+from pytorch_optimizer.optimizer.adahessian import AdaHessian
 from pytorch_optimizer.optimizer.adai import Adai
 from pytorch_optimizer.optimizer.adamax import AdaMax
 from pytorch_optimizer.optimizer.adamod import AdaMod
@@ -81,10 +82,9 @@
     power_iteration,
 )
 from pytorch_optimizer.optimizer.sm3 import SM3
+from pytorch_optimizer.optimizer.sophia import SophiaH
 from pytorch_optimizer.optimizer.srmm import SRMM
 from pytorch_optimizer.optimizer.swats import SWATS
-from pytorch_optimizer.optimizer.adahessian import AdaHessian
-from pytorch_optimizer.optimizer.sophiah import SophiaH
 from pytorch_optimizer.optimizer.utils import (
     clip_grad_norm,
     disable_running_stats,
@@ -151,7 +151,7 @@
     Amos,
     AdaHessian,
     SophiaH,
-    SignSGD
+    SignSGD,
 ]
 OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
 
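Because `OPTIMIZERS` maps lower-cased class names to classes (last context line of the hunk above), the new optimizers are also reachable by name. A small sketch, assuming only what the diff shows:

```python
# Registry lookup sketch (not from this commit).
from pytorch_optimizer import OPTIMIZERS

for name in ('adahessian', 'sophiah', 'signsgd'):
    print(name, '->', OPTIMIZERS[name].__name__)
```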
pytorch_optimizer/base/optimizer.py

Lines changed: 40 additions & 38 deletions
@@ -4,78 +4,80 @@
 
 import torch
 
-from pytorch_optimizer.base.exception import NegativeLRError, NegativeStepError, NoSparseGradientError
-from pytorch_optimizer.base.types import BETAS, HUTCHINSON_G
+from pytorch_optimizer.base.exception import NegativeLRError, NegativeStepError
+from pytorch_optimizer.base.types import BETAS, HUTCHINSON_G, PARAMETERS, STATE
 
 
 class BaseOptimizer(ABC):
     r"""Base optimizer class."""
 
+    @staticmethod
     @torch.no_grad()
-    def set_hessian(self, hessian):
-        """
-        Helper function to set hessian state from external source
-        Generally useful when using functorch as a base
+    def set_hessian(param_groups: PARAMETERS, state: STATE, hessian: List[torch.Tensor]):
+        r"""Set hessian to state from external source. Generally useful when using functorch as a base.
 
         Example usage:
         ```
         # Hutchinsons Estimator using HVP
         noise = tree_map(lambda v: torch.randn_like(v), params)
         loss_, hvp_est = jvp(grad(run_model_fn), (params,), (noise,))
-        hessian_diag_est = tree_map(lambda a, b: a*b, hvp_est, noise)
+        hessian_diag_est = tree_map(lambda a, b: a * b, hvp_est, noise)
 
         optimizer.set_hessian(hessian_diag_est)
         # OR
         optimizer.step(hessian=hessian_diag_est)
         ````
-
         """
-        i = 0
-        for group in self.param_groups:
+        i: int = 0
+        for group in param_groups:
             for p in group['params']:
-                assert p.shape == hessian[i].shape
-                self.state[p]['hessian'] = hessian[i]
+                if p.size() != hessian[i].size():
+                    raise ValueError(
+                        f'[-] the shape of parameter and hessian does not match. {p.size()} vs {hessian[i].size()}'
+                    )
+
+                state[p]['hessian'] = hessian[i]
                 i += 1
 
+    @staticmethod
     @torch.no_grad()
-    def compute_hutchinson_hessian(self, nsamples: int = 1, pre_zero=True, alpha=1.0, distribution: HUTCHINSON_G = 'gaussian'):
-        """
-        Hutchinsons approximate hessian, added to the state under key 'hessian'
-        """
-        if distribution not in ['gaussian', 'rademacher']:
-            raise NotImplementedError(f"Hessian with distribution {distribution} is not implemented")
+    def compute_hutchinson_hessian(
+        param_groups: PARAMETERS,
+        state: STATE,
+        num_samples: int = 1,
+        pre_zero: bool = True,
+        alpha: float = 1.0,
+        distribution: HUTCHINSON_G = 'gaussian',
+    ):
+        r"""Hutchinson's approximate hessian, added to the state under key `hessian`."""
+        if distribution not in ('gaussian', 'rademacher'):
+            raise NotImplementedError(f'[-] Hessian with distribution {distribution} is not implemented.')
 
         params = []
-        for group in self.param_groups:
+        for group in param_groups:
             for p in group['params']:
-                if p.requires_grad and p.grad is not None:
-                    if p.grad.is_sparse:
-                        raise NoSparseGradientError(str(self))
-                    # Initialize Hessian state
-                    if 'hessian' in self.state[p]:
-                        if pre_zero:
-                            self.state[p]['hessian'].zero_()
-                    else:
-                        self.state[p]['hessian'] = torch.zeros_like(p.data)
+                if p.requires_grad and p.grad is not None and not p.grad.is_sparse:
+                    if 'hessian' not in state[p]:
+                        state[p]['hessian'] = torch.zeros_like(p)
+                    elif pre_zero:
+                        state[p]['hessian'].zero_()
+
                     params.append(p)
 
         if len(params) == 0:
             return
 
         grads = [p.grad for p in params]
 
-        for i in range(nsamples):
-            if distribution == 'gaussian':
-                # Gaussian N(0,Id)
-                zs = [torch.randn(p.size(), device=p.device) for p in params]
-            elif distribution == 'rademacher':
-                # Rademacher distribution {-1.0, 1.0}
-                zs = [torch.randint(0, 2, p.size(), dtype=p.dtype, device=p.device) * 2.0 - 1.0 for p in params]
+        for i in range(num_samples):
+            if distribution == 'rademacher':
+                zs = [torch.randint_like(p, 0, 1) * 2.0 - 1.0 for p in params]
+            else:
+                zs = [torch.randn_like(p) for p in params]
 
-            h_zs = torch.autograd.grad(grads, params, grad_outputs=zs, retain_graph=i < nsamples - 1)
+            h_zs = torch.autograd.grad(grads, params, grad_outputs=zs, retain_graph=i < num_samples - 1)
             for h_z, z, p in zip(h_zs, zs, params):
-                # approximate the expected values of z*(H@z)
-                self.state[p]['hessian'].add_(h_z * z, alpha=(1/nsamples) * alpha)
+                state[p]['hessian'].add_(h_z * z, alpha=(1 / num_samples) * alpha)
 
     @staticmethod
     def apply_weight_decay(
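The hunk above turns `set_hessian` and `compute_hutchinson_hessian` into `@staticmethod`s that receive `param_groups` and `state` explicitly, so they can be reused by anything that exposes those two attributes. A minimal sketch of that reuse (not from this commit), using a plain `torch.optim.SGD` instance purely as a holder for `param_groups`/`state`; everything except the helper signature shown above is illustrative:

```python
import torch
from pytorch_optimizer.base.optimizer import BaseOptimizer

model = torch.nn.Linear(8, 1)
holder = torch.optim.SGD(model.parameters(), lr=1e-2)  # any holder of param_groups/state

loss = model(torch.randn(4, 8)).pow(2).mean()
loss.backward(create_graph=True)  # the helper takes HVPs through the retained graph

# Hutchinson estimate of diag(H), accumulated under state[p]['hessian'].
BaseOptimizer.compute_hutchinson_hessian(
    holder.param_groups, holder.state, num_samples=1, distribution='gaussian'
)

for p in model.parameters():
    print(tuple(p.shape), tuple(holder.state[p]['hessian'].shape))
```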

pytorch_optimizer/base/types.py

Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, Union, Literal
+from typing import Any, Callable, Dict, Iterable, Literal, Optional, Tuple, Type, Union
 
 import torch
 from torch.optim import Optimizer
@@ -13,4 +13,4 @@
 OPTIMIZER = Type[Optimizer]
 SCHEDULER = Type[_LRScheduler]
 
-HUTCHINSON_G = Literal['gaussian', 'rademacher']
+HUTCHINSON_G = Literal['gaussian', 'rademacher']
