
Commit 19dcf2b

Merge pull request #130 from kozistr/feature/sm3-optimizer
[Feature] Implement SM3 Optimizer
2 parents ec2d693 + c8fdf41 commit 19dcf2b

File tree

14 files changed: +330, -139 lines

README.rst

Lines changed: 4 additions & 0 deletions
@@ -138,6 +138,8 @@ You can check the supported optimizers & lr schedulers.
 +--------------+-------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
 | Ali-G | *Adaptive Learning Rates for Interpolation with Gradients* | `github <https://github.com/oval-group/ali-g>`__ | `https://arxiv.org/abs/1906.05661 <https://arxiv.org/abs/1906.05661>`__ |
 +--------------+-------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
+| SM3 | *Memory-Efficient Adaptive Optimization* | `github <https://github.com/google-research/google-research/tree/master/sm3>`__ | `https://arxiv.org/abs/1901.11150 <https://arxiv.org/abs/1901.11150>`__ |
++--------------+-------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
 
 Useful Resources
 ----------------
@@ -343,6 +345,8 @@ Citations
 
 `Ali-G <https://github.com/oval-group/ali-g#adaptive-learning-rates-for-interpolation-with-gradients>`__
 
+`SM3 <https://ui.adsabs.harvard.edu/abs/2019arXiv190111150A/exportcitation>`__
+
 Citation
 --------
 
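For reference, a minimal usage sketch of the newly listed optimizer (the toy model, data, and hyper-parameter values below are illustrative only and not part of this commit):

import torch
from pytorch_optimizer import SM3

model = torch.nn.Linear(8, 1)
criterion = torch.nn.MSELoss()

# constructor defaults are lr=1e-1, momentum=0.0, beta=0.0, eps=1e-30
optimizer = SM3(model.parameters(), lr=1e-1, momentum=0.9, beta=0.9)

x, y = torch.randn(16, 8), torch.randn(16, 1)
for _ in range(10):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()

SM3 is a drop-in torch.optim-style optimizer, so the loop above is the same as for any other entry in the table.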
docs/optimizer_api.rst

Lines changed: 8 additions & 0 deletions
@@ -280,3 +280,11 @@ AliG
 
 .. autoclass:: pytorch_optimizer.AliG
     :members:
+
+.. _SM3:
+
+SM3
+---
+
+.. autoclass:: pytorch_optimizer.SM3
+    :members:

poetry.lock

Lines changed: 92 additions & 112 deletions
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 4 additions & 4 deletions
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "pytorch_optimizer"
-version = "2.5.2"
+version = "2.6.0"
 description = "optimizer & lr scheduler implementations in PyTorch with clean-code, strict types. Also, including useful optimization ideas."
 license = "Apache-2.0"
 authors = ["kozistr <[email protected]>"]
@@ -38,7 +38,7 @@ numpy = [
     { version = "*", python = ">=3.8" },
 ]
 torch = [
-    { version = ">=1.10,>=2.0", python = ">=3.8", source = "torch" },
+    { version = ">=1.10", python = ">=3.8", source = "torch" },
     { version = "^1.10", python = ">=3.7,<3.8", source = "torch" },
 ]
 
@@ -48,8 +48,8 @@ isort = [
     { version = "^5.12.0", python = ">=3.8"}
 ]
 black = "^23.3.0"
-ruff = "^0.0.260"
-pytest = "^7.2.2"
+ruff = "^0.0.262"
+pytest = "^7.3.1"
 pytest-cov = "^4.0.0"
 
 [[tool.poetry.source]]

pytorch_optimizer/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -64,6 +64,7 @@
     merge_small_dims,
     power_iteration,
 )
+from pytorch_optimizer.optimizer.sm3 import SM3
 from pytorch_optimizer.optimizer.utils import (
     clip_grad_norm,
     disable_running_stats,
@@ -103,6 +104,7 @@
     NovoGrad,
     Lion,
     AliG,
+    SM3,
 ]
 OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
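As a quick sanity check that the new class is wired into the registry defined above (a minimal sketch, assuming the package built from this commit is installed):

from pytorch_optimizer import OPTIMIZERS, SM3

# OPTIMIZER_LIST entries are keyed by their lowercase class names
assert OPTIMIZERS['sm3'] is SM3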

pytorch_optimizer/optimizer/sm3.py

Lines changed: 158 additions & 0 deletions
@@ -0,0 +1,158 @@
import torch
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS


class SM3(Optimizer, BaseOptimizer):
    r"""Memory-Efficient Adaptive Optimization.

    Reference : https://github.com/Enealor/PyTorch-SM3/blob/master/src/SM3/SM3.py

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param momentum: float. coefficient used to scale prior updates before adding. This drastically increases
        memory usage if `momentum > 0.0`. This is ignored if the parameter's gradient is sparse.
    :param beta: float. coefficient used for exponential moving averages.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-1,
        momentum: float = 0.0,
        beta: float = 0.0,
        eps: float = 1e-30,
    ):
        self.lr = lr
        self.momentum = momentum
        self.beta = beta
        self.eps = eps

        self.validate_parameters()

        defaults: DEFAULTS = {'lr': lr, 'momentum': momentum, 'beta': beta}
        super().__init__(params, defaults)

    def validate_parameters(self):
        self.validate_learning_rate(self.lr)
        self.validate_momentum(self.momentum)
        self.validate_beta(self.beta)
        self.validate_epsilon(self.eps)

    def __str__(self) -> str:
        return 'SM3'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]

                state['step'] = 0
                state['momentum_buffer'] = torch.zeros_like(p)

    @staticmethod
    def max_reduce_except_dim(x: torch.Tensor, dim: int) -> torch.Tensor:
        r"""Perform reduce-max along all dimensions except the given dim."""
        rank: int = len(x.shape)
        if rank == 0:
            return x

        if dim >= rank:
            raise ValueError(f'[-] given dim is bigger than rank. {dim} >= {rank}')

        for d in range(rank):
            if d != dim:
                x = x.max(dim=d, keepdim=True).values
        return x

    @staticmethod
    def make_sparse(grad: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
        if grad._indices().dim() == 0 or values.dim() == 0:
            return grad.new().resize_as_(grad)
        return grad.new(grad._indices(), values, grad.size())

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            momentum, beta = group['momentum'], group['beta']
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                shape = grad.shape
                rank: int = len(shape)

                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    state['momentum_buffer'] = torch.zeros_like(p)

                    if grad.is_sparse:
                        state['accumulator_0'] = torch.zeros(shape[0])
                    elif rank == 0:
                        state['accumulator_0'] = torch.zeros(shape)
                    else:
                        for i in range(rank):
                            state[f'accumulator_{i}'] = torch.zeros([1] * i + [shape[i]] + [1] * (rank - 1 - i))

                state['step'] += 1

                if grad.is_sparse:
                    # sparse path: a single accumulator over the leading dimension is kept
                    grad = grad.coalesce()

                    acc = state['accumulator_0']
                    update_values = torch.gather(acc, 0, grad._indices()[0])
                    if beta > 0.0:
                        update_values.mul_(beta)
                    update_values.addcmul_(grad._values(), grad._values(), value=1.0 - beta)

                    nu_max = self.max_reduce_except_dim(
                        x=self.make_sparse(grad, update_values).to_dense(),
                        dim=0,
                    ).squeeze_()

                    if beta > 0.0:
                        torch.max(acc, nu_max, out=acc)
                    else:
                        acc.copy_(nu_max)

                    update_values.add_(self.eps).rsqrt_().mul_(grad._values())

                    update = self.make_sparse(grad, update_values)
                else:
                    # dense path: combine the per-dimension accumulators via element-wise min
                    update = state['accumulator_0'].clone()
                    for i in range(1, rank):
                        update = torch.min(update, state[f'accumulator_{i}'])

                    if beta > 0.0:
                        update.mul_(beta)
                    update.addcmul_(grad, grad, value=1.0 - beta)

                    for i in range(rank):
                        acc = state[f'accumulator_{i}']
                        nu_max = self.max_reduce_except_dim(update, i)
                        if beta > 0.0:
                            torch.max(acc, nu_max, out=acc)
                        else:
                            acc.copy_(nu_max)

                    update.add_(self.eps).rsqrt_().mul_(grad)

                    if momentum > 0.0:
                        m = state['momentum_buffer']
                        m.mul_(momentum).add_(update, alpha=1.0 - momentum)
                        update = m

                p.add_(update, alpha=-group['lr'])

        return loss
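To illustrate what the dense branch of `step` actually stores, here is a small sketch; the parameter size is arbitrary and chosen only to make the shapes readable:

import torch
from pytorch_optimizer import SM3

weight = torch.nn.Parameter(torch.randn(1000, 300))
optimizer = SM3([weight], lr=1e-1)

weight.grad = torch.randn_like(weight)
optimizer.step()

state = optimizer.state[weight]
# one accumulator per dimension instead of a full second-moment tensor:
# accumulator_0 has shape (1000, 1), accumulator_1 has shape (1, 300),
# i.e. roughly 1,300 tracked values rather than 300,000
print(state['accumulator_0'].shape, state['accumulator_1'].shape)

The momentum buffer, by contrast, is a full-size tensor, which is why the docstring flags `momentum > 0.0` as memory-hungry.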

requirements-dev.txt

Lines changed: 7 additions & 8 deletions
@@ -1,30 +1,29 @@
 --extra-index-url https://download.pytorch.org/whl/cpu
 
-attrs==22.2.0 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
 black==23.3.0 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
 click==8.1.3 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
 colorama==0.4.6 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0" and sys_platform == "win32" or python_full_version >= "3.7.2" and python_full_version < "4.0.0" and platform_system == "Windows"
-coverage[toml]==7.2.2 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
+coverage[toml]==7.2.3 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
 exceptiongroup==1.1.1 ; python_full_version >= "3.7.2" and python_version < "3.11"
-filelock==3.10.7 ; python_version >= "3.8" and python_full_version < "4.0.0"
-importlib-metadata==6.1.0 ; python_full_version >= "3.7.2" and python_version < "3.8"
+filelock==3.12.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
+importlib-metadata==6.5.1 ; python_full_version >= "3.7.2" and python_version < "3.8"
 iniconfig==2.0.0 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
 isort==5.11.5 ; python_full_version >= "3.7.2" and python_version < "3.8"
 isort==5.12.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
 jinja2==3.1.2 ; python_version >= "3.8" and python_full_version < "4.0.0"
 markupsafe==2.1.2 ; python_version >= "3.8" and python_full_version < "4.0.0"
 mpmath==1.3.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
 mypy-extensions==1.0.0 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
-networkx==3.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
+networkx==3.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
 numpy==1.21.1 ; python_full_version >= "3.7.2" and python_version < "3.8"
 numpy==1.24.2 ; python_version >= "3.8" and python_full_version < "4.0.0"
-packaging==23.0 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
+packaging==23.1 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
 pathspec==0.11.1 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
 platformdirs==3.2.0 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
 pluggy==1.0.0 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
 pytest-cov==4.0.0 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
-pytest==7.2.2 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
-ruff==0.0.260 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
+pytest==7.3.1 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
+ruff==0.0.262 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
 sympy==1.11.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
 tomli==2.0.1 ; python_full_version >= "3.7.2" and python_full_version <= "3.11.0a6"
 torch==1.13.1+cpu ; python_full_version >= "3.7.2" and python_version < "3.8"

requirements.txt

Lines changed: 2 additions & 2 deletions
@@ -1,10 +1,10 @@
 --extra-index-url https://download.pytorch.org/whl/cpu
 
-filelock==3.10.7 ; python_version >= "3.8" and python_full_version < "4.0.0"
+filelock==3.12.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
 jinja2==3.1.2 ; python_version >= "3.8" and python_full_version < "4.0.0"
 markupsafe==2.1.2 ; python_version >= "3.8" and python_full_version < "4.0.0"
 mpmath==1.3.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
-networkx==3.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
+networkx==3.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
 numpy==1.21.1 ; python_full_version >= "3.7.2" and python_version < "3.8"
 numpy==1.24.2 ; python_version >= "3.8" and python_full_version < "4.0.0"
 sympy==1.11.1 ; python_version >= "3.8" and python_full_version < "4.0.0"

tests/constants.py

Lines changed: 3 additions & 1 deletion
@@ -6,6 +6,7 @@
     OPTIMIZERS,
     PNM,
     SGDP,
+    SM3,
     AdaBelief,
     AdaBound,
     AdaFactor,
@@ -47,7 +48,7 @@
     'lookahead',
 ]
 
-SPARSE_OPTIMIZERS: List[str] = ['madgrad', 'dadaptadagrad']
+SPARSE_OPTIMIZERS: List[str] = ['madgrad', 'dadaptadagrad', 'sm3']
 NO_SPARSE_OPTIMIZERS: List[str] = [
     optimizer for optimizer in VALID_OPTIMIZER_NAMES if optimizer not in SPARSE_OPTIMIZERS
 ]
@@ -300,6 +301,7 @@
     (Lion, {'lr': 5e-1, 'weight_decay': 1e-3, 'weight_decouple': False}, 10),
     (AliG, {'max_lr': 5e-1, 'momentum': 0.9}, 10),
     (AliG, {'max_lr': 5e-1, 'momentum': 0.9, 'adjusted_momentum': True}, 10),
+    (SM3, {'lr': 5e-1, 'momentum': 0.9, 'beta': 0.9}, 10),
 ]
 ADAMD_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
     (build_lookahead, {'lr': 5e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 10),

tests/test_gradients.py

Lines changed: 13 additions & 8 deletions
@@ -49,8 +49,12 @@ def test_sparse(sparse_optimizer):
 
     weight, weight_sparse = simple_sparse_parameter()
 
-    opt_dense = opt([weight], lr=1e-3, momentum=0.0)
-    opt_sparse = opt([weight_sparse], lr=1e-3, momentum=0.0)
+    params = {'lr': 1e-3, 'momentum': 0.0}
+    if sparse_optimizer == 'sm3':
+        params.update({'beta': 0.9})
+
+    opt_dense = opt([weight], **params)
+    opt_sparse = opt([weight_sparse], **params)
 
     opt_dense.step()
     opt_sparse.step()
@@ -89,13 +93,14 @@ def test_sparse_supported(sparse_optimizer):
         with pytest.raises(NoSparseGradientError):
             optimizer.step()
 
-    optimizer = opt([simple_sparse_parameter()[1]], momentum=0.9, weight_decay=1e-3)
-    optimizer.reset()
-    if sparse_optimizer == 'madgrad':
-        with pytest.raises(NoSparseGradientError):
+    if sparse_optimizer in ('madgrad', 'dadapt'):
+        optimizer = opt([simple_sparse_parameter()[1]], momentum=0.9, weight_decay=1e-3)
+        optimizer.reset()
+        if sparse_optimizer == 'madgrad':
+            with pytest.raises(NoSparseGradientError):
+                optimizer.step()
+        else:
             optimizer.step()
-    else:
-        optimizer.step()
 
 
 @pytest.mark.parametrize('optimizer_name', VALID_OPTIMIZER_NAMES)
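The sparse branch exercised by these tests can also be triggered directly; a small hedged sketch (the hand-built COO gradient below is illustrative and not taken from the test suite):

import torch
from pytorch_optimizer import SM3

# 1-D parameter that receives a sparse COO gradient on two indices
weight = torch.nn.Parameter(torch.randn(10))
optimizer = SM3([weight], lr=1e-1, beta=0.9)

indices = torch.tensor([[1, 3]])
values = torch.randn(2)
weight.grad = torch.sparse_coo_tensor(indices, values, (10,))

optimizer.step()  # takes the single-'accumulator_0' sparse path in SM3.step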
