Skip to content

Commit 3d4d440

Browse files
authored
Merge pull request #258 from kozistr/feature/adalomo-optimizer
[Feature] Implement `AdaLOMO` optimizer and others
2 parents acd218b + 2acade3 commit 3d4d440

20 files changed

+675
-318
lines changed

README.md

Lines changed: 80 additions & 78 deletions
Large diffs are not rendered by default.

docs/changelogs/v3.1.0.md

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
## Change Log
2+
3+
### Feature
4+
5+
* Implement `AdaLomo` optimizer. (#258)
6+
* [AdaLomo: Low-memory Optimization with Adaptive Learning Rate](https://arxiv.org/abs/2310.10195)
7+
* Support `Q-GaLore` optimizer. (#258)
8+
* [Q-GaLore: Quantized GaLore with INT4 Projection and Layer-Adaptive Low-Rank Gradients.](https://arxiv.org/abs/2407.08296)
9+
* you can load it via `optimizer = load_optimizer('q_galore_adamw8bit')`
10+
* Support more bnb optimizers. (#258)
11+
* `bnb_paged_adam8bit`, `bnb_paged_adamw8bit`, `bnb_*_*32bit`.
12+
13+
### Refactor
14+
15+
* Refactor `AdamMini`. (#258)
16+
* Deprecate optional dependency, `bitsandbytes`. (#258)
17+
* Move `get_rms`, `approximate_sq_grad` functions to `BaseOptimizer` for reusability. (#258)
18+
19+
### Bug
20+
21+
* Fix several bugs in `AdamMini` optimizer. (#257)
22+
23+
## Contributions
24+
25+
thanks to @sdbds

docs/index.md

Lines changed: 85 additions & 79 deletions
Large diffs are not rendered by default.

docs/optimizer.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@
2424
:docstring:
2525
:members:
2626

27+
::: pytorch_optimizer.AdaLOMO
28+
:docstring:
29+
:members:
30+
2731
::: pytorch_optimizer.Adai
2832
:docstring:
2933
:members:

examples/visualize_optimizers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def main():
158158
]
159159

160160
for optimizer_name, optimizer in OPTIMIZERS.items():
161-
if optimizer_name.lower() in {'alig', 'lomo', 'bsam', 'adammini'}:
161+
if optimizer_name.lower() in {'alig', 'lomo', 'adalomo', 'bsam', 'adammini'}:
162162
continue
163163

164164
optimizers.append((optimizer, -6, 0.2))

poetry.lock

Lines changed: 65 additions & 81 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "pytorch_optimizer"
3-
version = "3.0.2"
3+
version = "3.1.0"
44
description = "optimizer & lr scheduler & objective function collections in PyTorch"
55
license = "Apache-2.0"
66
authors = ["kozistr <[email protected]>"]
@@ -12,13 +12,14 @@ documentation = "https://pytorch-optimizers.readthedocs.io/en/latest"
1212
keywords = [
1313
"pytorch", "deep-learning", "optimizer", "lr scheduler", "A2Grad", "ASGD", "AccSGD", "AdaBelief", "AdaBound",
1414
"AdaDelta", "AdaFactor", "AdaMax", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdaHessian", "Adai", "Adalite",
15-
"AdamMini", "AdamP", "AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos", "Apollo", "AvaGrad", "bSAM", "CAME",
16-
"DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan", "DAdaptSGD", "DAdaptLion", "DiffGrad", "FAdam", "Fromage", "GaLore",
17-
"Gravity", "GrokFast", "GSAM", "Kate", "Lamb", "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Nero",
18-
"NovoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad",
19-
"SAM", "ScheduleFreeSGD", "ScheduleFreeAdamW", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3",
20-
"SopihaH", "SRMM", "StableAdamW", "SWATS", "Tiger", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine",
21-
"SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD",
15+
"AdaLomo", "AdamMini", "AdamP", "AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos", "Apollo", "AvaGrad", "bSAM",
16+
"CAME", "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan", "DAdaptSGD", "DAdaptLion", "DiffGrad", "FAdam", "Fromage",
17+
"GaLore", "Gravity", "GrokFast", "GSAM", "Kate", "Lamb", "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG",
18+
"Nero", "NovoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21",
19+
"RotoGrad", "SAM", "ScheduleFreeSGD", "ScheduleFreeAdamW", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD",
20+
"SM3", "SopihaH", "SRMM", "StableAdamW", "SWATS", "Tiger", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal",
21+
"FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge",
22+
"bitsandbytes", "WSD", "QGaLore",
2223
]
2324
classifiers = [
2425
"License :: OSI Approved :: Apache Software License",
@@ -46,7 +47,6 @@ classifiers = [
4647
python = ">=3.8,<4.0.0"
4748
numpy = { version = "*", python = ">=3.8" }
4849
torch = { version = ">=1.10", python = ">=3.8", source = "torch" }
49-
bitsandbytes = { version = "^0.43", optional = true }
5050

5151
[tool.poetry.dev-dependencies]
5252
isort = { version = "^5", python = ">=3.8" }
@@ -55,9 +55,6 @@ ruff = "*"
5555
pytest = "*"
5656
pytest-cov = "*"
5757

58-
[tool.poetry.extras]
59-
bitsandbytes = ["bitsandbytes"]
60-
6158
[[tool.poetry.source]]
6259
name = "torch"
6360
url = "https://download.pytorch.org/whl/cpu"

pytorch_optimizer/__init__.py

Lines changed: 54 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# ruff: noqa
2+
from importlib.util import find_spec
23
from typing import Dict, List
34

45
import torch.cuda
@@ -72,7 +73,7 @@
7273
from pytorch_optimizer.optimizer.lamb import Lamb
7374
from pytorch_optimizer.optimizer.lars import LARS
7475
from pytorch_optimizer.optimizer.lion import Lion
75-
from pytorch_optimizer.optimizer.lomo import LOMO
76+
from pytorch_optimizer.optimizer.lomo import LOMO, AdaLOMO
7677
from pytorch_optimizer.optimizer.lookahead import Lookahead
7778
from pytorch_optimizer.optimizer.madgrad import MADGRAD
7879
from pytorch_optimizer.optimizer.msvag import MSVAG
@@ -126,12 +127,8 @@
126127
)
127128
from pytorch_optimizer.optimizer.yogi import Yogi
128129

129-
try:
130-
import bitsandbytes as bnb
131-
132-
HAS_BNB: bool = True # pragma: no cover
133-
except ImportError:
134-
HAS_BNB: bool = False
130+
HAS_BNB: bool = find_spec('bitsandbytes') is not None
131+
HAS_Q_GALORE: bool = find_spec('q-galore-torch') is not None
135132

136133
OPTIMIZER_LIST: List[OPTIMIZER] = [
137134
AdaBelief,
@@ -205,6 +202,7 @@
205202
Kate,
206203
StableAdamW,
207204
AdamMini,
205+
AdaLOMO,
208206
]
209207
OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
210208

@@ -252,22 +250,58 @@
252250

253251
def load_bnb_optimizer(optimizer: str) -> OPTIMIZER: # pragma: no cover
254252
r"""load bnb optimizer instance."""
253+
from bitsandbytes import optim
254+
255255
if 'sgd8bit' in optimizer:
256-
return bnb.optim.SGD8bit
256+
return optim.SGD8bit
257257
if 'adam8bit' in optimizer:
258-
return bnb.optim.Adam8bit
258+
return optim.Adam8bit
259+
if 'paged_adam8bit' in optimizer:
260+
return optim.PagedAdam8bit
259261
if 'adamw8bit' in optimizer:
260-
return bnb.optim.AdamW8bit
262+
return optim.AdamW8bit
263+
if 'paged_adamw8bit' in optimizer:
264+
return optim.PagedAdamW8bit
261265
if 'lamb8bit' in optimizer:
262-
return bnb.optim.LAMB8bit
266+
return optim.LAMB8bit
263267
if 'lars8bit' in optimizer:
264-
return bnb.optim.LARS8bit
268+
return optim.LARS8bit
265269
if 'lion8bit' in optimizer:
266-
return bnb.optim.Lion8bit
270+
return optim.Lion8bit
267271
if 'adagrad8bit' in optimizer:
268-
return bnb.optim.Adagrad8bit
272+
return optim.Adagrad8bit
269273
if 'rmsprop8bit' in optimizer:
270-
return bnb.optim.RMSprop8bit
274+
return optim.RMSprop8bit
275+
if 'adagrad32bit' in optimizer:
276+
return optim.Adagrad32bit
277+
if 'adam32bit' in optimizer:
278+
return optim.Adam32bit
279+
if 'paged_adam32bit' in optimizer:
280+
return optim.PagedAdam32bit
281+
if 'adamw32bit' in optimizer:
282+
return optim.AdamW32bit
283+
if 'lamb32bit' in optimizer:
284+
return optim.LAMB32bit
285+
if 'lars32bit' in optimizer:
286+
return optim.LARS32bit
287+
if 'lion32bit' in optimizer:
288+
return optim.Lion32bit
289+
if 'paged_lion32bit' in optimizer:
290+
return optim.PagedLion32bit
291+
if 'rmsprop32bit' in optimizer:
292+
return optim.RMSprop32bit
293+
if 'sgd32bit' in optimizer:
294+
return optim.SGD32bit
295+
raise NotImplementedError(f'[-] not implemented optimizer : {optimizer}')
296+
297+
298+
def load_q_galore_optimizer(optimizer: str) -> OPTIMIZER: # pragma: no cover
299+
r"""load Q-GaLore optimizer instance."""
300+
import q_galore_torch
301+
302+
if 'adamw8bit' in optimizer:
303+
return q_galore_torch.QGaLoreAdamW8bit
304+
271305
raise NotImplementedError(f'[-] not implemented optimizer : {optimizer}')
272306

273307

@@ -277,7 +311,11 @@ def load_optimizer(optimizer: str) -> OPTIMIZER:
277311
if optimizer.startswith('bnb'):
278312
if HAS_BNB and torch.cuda.is_available():
279313
return load_bnb_optimizer(optimizer) # pragma: no cover
280-
raise ImportError(f'[-] bitsandbytes and CUDA required for bnb optimizers : {optimizer}')
314+
raise ImportError(f'[-] bitsandbytes and CUDA required for the optimizer {optimizer}')
315+
if optimizer.startswith('q_galore'):
316+
if HAS_Q_GALORE and torch.cuda.is_available():
317+
return load_q_galore_optimizer(optimizer) # pragma: no cover
318+
raise ImportError(f'[-] bitsandbytes, q-galore-torch, and CUDA required for the optimizer {optimizer}')
281319
if optimizer not in OPTIMIZERS:
282320
raise NotImplementedError(f'[-] not implemented optimizer : {optimizer}')
283321

pytorch_optimizer/base/optimizer.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,22 @@ def get_adanorm_gradient(
214214

215215
return grad * exp_grad_norm / grad_norm if exp_grad_norm > grad_norm else grad
216216

217+
@staticmethod
218+
def get_rms(x: torch.Tensor) -> float:
219+
r"""Get RMS."""
220+
return x.norm(2) / math.sqrt(x.numel())
221+
222+
@staticmethod
223+
def approximate_sq_grad(
224+
exp_avg_sq_row: torch.Tensor,
225+
exp_avg_sq_col: torch.Tensor,
226+
output: torch.Tensor,
227+
) -> None:
228+
r"""Get approximation of EMA of squared gradient."""
229+
r_factor: torch.Tensor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
230+
c_factor: torch.Tensor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
231+
torch.mul(r_factor, c_factor, out=output)
232+
217233
@staticmethod
218234
def validate_range(x: float, name: str, low: float, high: float, range_type: str = '[)') -> None:
219235
if range_type == '[)' and not low <= x < high:

pytorch_optimizer/optimizer/adafactor.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -127,22 +127,6 @@ def get_options(shape: Tuple[int, ...]) -> bool:
127127
r"""Get `factored`."""
128128
return len(shape) >= 2
129129

130-
@staticmethod
131-
def get_rms(x: torch.Tensor) -> float:
132-
r"""Get RMS."""
133-
return x.norm(2) / math.sqrt(x.numel())
134-
135-
@staticmethod
136-
def approximate_sq_grad(
137-
exp_avg_sq_row: torch.Tensor,
138-
exp_avg_sq_col: torch.Tensor,
139-
output: torch.Tensor,
140-
):
141-
r"""Get approximation of EMA of squared gradient."""
142-
r_factor: torch.Tensor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
143-
c_factor: torch.Tensor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
144-
torch.mul(r_factor, c_factor, out=output)
145-
146130
@torch.no_grad()
147131
def step(self, closure: CLOSURE = None) -> LOSS:
148132
loss: LOSS = None

0 commit comments

Comments
 (0)