
Commit 55d8ba5

Merge pull request #188 from kozistr/feature/lomo-optimizer
[Feature] Implement LOMO optimizer
2 parents 8259768 + ed73c36 commit 55d8ba5

20 files changed: +371, -110 lines

README.rst

Lines changed: 3 additions & 1 deletion
@@ -16,7 +16,7 @@ pytorch-optimizer
 
 | **pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch.
 | I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas.
-| Currently, 56 optimizers, 6 lr schedulers are supported!
+| Currently, 57 optimizers, 6 lr schedulers are supported!
 |
 | Highly inspired by `pytorch-optimizer <https://github.com/jettify/pytorch-optimizer>`__.
@@ -218,6 +218,8 @@ You can check the supported optimizers with below code.
 +--------------+---------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+
 | PAdam | *Closing the Generalization Gap of Adaptive Gradient Methods in Training Deep Neural Networks* | `github <https://github.com/uclaml/Padam>`__ | `https://arxiv.org/abs/1806.06763 <https://arxiv.org/abs/1806.06763>`__ | `cite <https://github.com/uclaml/Padam#citation>`__ |
 +--------------+---------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+
+| LOMO | *Full Parameter Fine-tuning for Large Language Models with Limited Resources* | `github <https://github.com/OpenLMLab/LOMO>`__ | `https://arxiv.org/abs/2306.09782 <https://arxiv.org/abs/2306.09782>`__ | `cite <https://github.com/OpenLMLab/LOMO#citation>`__ |
++--------------+---------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+
 
 Supported LR Scheduler
 ----------------------
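
LOMO's contribution, per the linked paper, is fusing the gradient computation and the parameter update so that full-parameter fine-tuning fits in limited GPU memory. The sketch below shows how the new optimizer might be driven; it assumes the implementation mirrors the upstream OpenLMLab interface, where the optimizer is constructed from the model (so it can hook the backward pass) and a fused_backward(loss, lr) call replaces the usual loss.backward() plus optimizer.step() pair. The constructor and method signatures here are assumptions, not confirmed by this commit.

```python
import torch
import torch.nn as nn

from pytorch_optimizer import LOMO

model = nn.Linear(16, 2)
criterion = nn.CrossEntropyLoss()

# Assumption: LOMO takes the model itself so it can register backward hooks.
optimizer = LOMO(model, lr=1e-3)

x, y = torch.randn(8, 16), torch.randint(0, 2, (8,))
loss = criterion(model(x), y)

# Assumption: the fused call runs the backward pass and applies the update
# parameter-by-parameter, so full gradients never need to be held in memory at once.
optimizer.fused_backward(loss, lr=1e-3)
```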

docs/changelogs/v2.11.0.md

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,8 @@
 
 * Implement PAdam optimizer (#186)
     * [Closing the Generalization Gap of Adaptive Gradient Methods in Training Deep Neural Networks](https://arxiv.org/abs/1806.06763)
+* Implement LOMO optimizer (#188)
+    * [Full Parameter Fine-tuning for Large Language Models with Limited Resources](https://arxiv.org/abs/2306.09782)
 
 ### Diff

docs/optimizer_api.rst

Lines changed: 8 additions & 0 deletions
@@ -504,3 +504,11 @@ PAdam
 
 .. autoclass:: pytorch_optimizer.PAdam
     :members:
+
+.. _LOMO:
+
+LOMO
+----
+
+.. autoclass:: pytorch_optimizer.LOMO
+    :members:

poetry.lock

Lines changed: 43 additions & 74 deletions
(Generated file; diff not rendered.)

pyproject.toml

Lines changed: 16 additions & 9 deletions
@@ -13,9 +13,10 @@ keywords = [
     "pytorch", "deep-learning", "optimizer", "lr scheduler", "A2Grad", "ASGD", "AccSGD", "AdaBelief", "AdaBound",
     "AdaDelta", "AdaFactor", "AdaMax", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdaHessian", "Adai", "AdamP",
     "AdamS", "Adan", "AggMo", "AliG", "Amos", "Apollo", "AvaGrad", "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan",
-    "DAdaptSGD", "DiffGrad", "Fromage", "Gravity", "GSAM", "LARS", "Lamb", "Lion", "MADGRAD", "MSVAG", "Nero",
-    "NovoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad",
-    "SAM", "SGDP", "SGDW", "SignSGD", "SM3", "SopihaH", "SRMM", "SWATS", "ScalableShampoo", "Shampoo", "Yogi",
+    "DAdaptSGD", "DiffGrad", "Fromage", "Gravity", "GSAM", "LARS", "Lamb", "Lion", "LOMO", "Lookahead", "MADGRAD",
+    "MSVAG", "Nero", "NovoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger",
+    "Ranger21", "RotoGrad", "SAM", "SGDP", "SGDW", "SignSGD", "SM3", "SopihaH", "SRMM", "SWATS", "ScalableShampoo",
+    "Shampoo", "Yogi",
 ]
 classifiers = [
     "License :: OSI Approved :: Apache Software License",
@@ -55,18 +56,24 @@ isort = [
     { version = "^5.12.0", python = ">=3.8"}
 ]
 black = "^23.3.0"
-ruff = "^0.0.272"
-pytest = "^7.3.1"
+ruff = "^0.0.275"
+pytest = "^7.4.0"
 pytest-cov = "^4.1.0"
 
 [[tool.poetry.source]]
 name = "torch"
 url = "https://download.pytorch.org/whl/cpu"
-secondary = true
+priority = "explicit"
 
 [tool.ruff]
-select = ["A", "B", "C4", "D", "E", "F", "G", "I", "N", "S", "T", "ISC", "ICN", "W", "INP", "PIE", "T20", "RET", "SIM", "TID", "ARG", "ERA", "RUF", "YTT", "PL"]
-ignore = ["D100", "D102", "D104", "D105", "D107", "D203", "D213", "PIE790", "PLR0912", "PLR0913", "PLR0915", "PLR2004"]
+select = [
+    "A", "B", "C4", "D", "E", "F", "G", "I", "N", "S", "T", "ISC", "ICN", "W", "INP", "PIE", "T20", "RET", "SIM",
+    "TID", "ARG", "ERA", "RUF", "YTT", "PL",
+]
+ignore = [
+    "B905", "D100", "D102", "D104", "D105", "D107", "D203", "D213", "PIE790", "PLR0912", "PLR0913", "PLR0915",
+    "PLR2004", "RUF013",
+]
 fixable = ["A", "B", "C", "D", "E", "F"]
 unfixable = ["F401"]
 exclude = [
@@ -84,7 +91,7 @@ exclude = [
 ]
 line-length = 119
 dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
-target-version = "py39"
+target-version = "py311"
 
 [tool.ruff.per-file-ignores]
 "./hubconf.py" = ["D", "INP001"]

pytorch_optimizer/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -48,6 +48,7 @@
 from pytorch_optimizer.optimizer.lamb import Lamb
 from pytorch_optimizer.optimizer.lars import LARS
 from pytorch_optimizer.optimizer.lion import Lion
+from pytorch_optimizer.optimizer.lomo import LOMO
 from pytorch_optimizer.optimizer.lookahead import Lookahead
 from pytorch_optimizer.optimizer.madgrad import MADGRAD
 from pytorch_optimizer.optimizer.msvag import MSVAG
@@ -156,6 +157,7 @@
     SignSGD,
     Prodigy,
     PAdam,
+    LOMO,
 ]
 OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
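
Because the registry keys are built with str(optimizer.__name__).lower(), the new class is reachable under the string 'lomo'. A quick sanity check against the module-level names added above (both symbols are defined in pytorch_optimizer/__init__.py, so they should be importable directly):

```python
from pytorch_optimizer import LOMO, OPTIMIZERS

# Keys are the lower-cased class names, so the new entry registers as 'lomo'.
assert OPTIMIZERS['lomo'] is LOMO

# The same dictionary backs string-based optimizer lookup across the library.
print(len(OPTIMIZERS), 'optimizers registered')
```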

pytorch_optimizer/base/optimizer.py

Lines changed: 3 additions & 3 deletions
@@ -226,8 +226,8 @@ def validate_range(x: float, name: str, low: float, high: float, range_type: str
             raise ValueError(f'[-] {name} must be in the range ({low}, {high})')
 
     @staticmethod
-    def validate_non_negative(x: float, name: str):
-        if x < 0.0:
+    def validate_non_negative(x: Optional[float], name: str):
+        if x is not None and x < 0.0:
             raise ValueError(f'[-] {name} must be non-negative')
 
     @staticmethod
@@ -276,5 +276,5 @@ def validate_nus(self, nus: Union[float, Tuple[float, float]]):
         self.validate_range(nus[1], 'nu2', 0.0, 1.0, range_type='[]')
 
     @abstractmethod
-    def reset(self):
+    def reset(self):  # pragma: no cover
         raise NotImplementedError
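
After this change, passing None to the validator means "not configured" and is skipped instead of failing on the comparison. A small illustration of the new behavior; it assumes the static method lives on the BaseOptimizer mixin defined in this module, since the host class name is not visible in the hunk:

```python
from pytorch_optimizer.base.optimizer import BaseOptimizer  # assumed host class for the validator

BaseOptimizer.validate_non_negative(0.0, 'weight_decay')     # fine: zero is allowed
BaseOptimizer.validate_non_negative(None, 'max_grad_norm')   # fine after this change: None is skipped

try:
    BaseOptimizer.validate_non_negative(-1e-3, 'weight_decay')
except ValueError as exc:
    print(exc)  # [-] weight_decay must be non-negative
```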

pytorch_optimizer/base/scheduler.py

Lines changed: 1 addition & 1 deletion
@@ -86,7 +86,7 @@ def step(self):
         return value
 
     @abstractmethod
-    def _step(self) -> float:
+    def _step(self) -> float:  # pragma: no cover
         raise NotImplementedError
 
     def get_lr(self) -> float:

pytorch_optimizer/base/types.py

Lines changed: 4 additions & 4 deletions
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, Iterable, Literal, Optional, Tuple, Type, Union
+from typing import Callable, Dict, Iterable, Literal, Optional, Tuple, Type, Union
 
 import torch
 from torch.optim import Optimizer
@@ -7,9 +7,9 @@
 CLOSURE = Optional[Callable[[], float]]
 LOSS = Optional[float]
 BETAS = Union[Tuple[float, float], Tuple[float, float, float]]
-DEFAULTS = Dict[str, Any]
-PARAMETERS = Optional[Union[Iterable[Dict[str, Any]], Iterable[torch.Tensor]]]
-STATE = Dict[str, Any]
+DEFAULTS = Dict
+PARAMETERS = Optional[Union[Iterable[Dict], Iterable[torch.Tensor]]]
+STATE = Dict
 OPTIMIZER = Type[Optimizer]
 SCHEDULER = Type[_LRScheduler]
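
With Any dropped from the import list, the aliases fall back to unparameterized Dict, which still annotates the same call sites without the extra import. The helper below is illustrative only (not part of the library) and shows how these aliases are typically consumed by an optimizer-style signature:

```python
from pytorch_optimizer.base.types import DEFAULTS, PARAMETERS


def describe(params: PARAMETERS, defaults: DEFAULTS) -> None:
    """Hypothetical helper: params may be tensors, per-group dicts, or None; defaults is a plain dict."""
    n_items = 0 if params is None else sum(1 for _ in params)
    print(n_items, defaults.get('lr'))


describe(params=None, defaults={'lr': 1e-3, 'weight_decay': 0.0})
```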

pytorch_optimizer/optimizer/fp16.py

Lines changed: 1 addition & 0 deletions
@@ -48,6 +48,7 @@ def __init__(
         self.last_overflow_iter: int = -1
         self.last_rescale_iter: int = -1
         self.overflows_since_rescale: int = 0
+        self.has_overflow_serial: bool = False
 
     def update_scale(self, overflow: bool):
         r"""Update the loss scale.
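
The new has_overflow_serial flag joins the other overflow-tracking counters of what is evidently a dynamic loss scaler. The class below is a simplified stand-in that sketches the generic dynamic loss-scaling pattern (back off on overflow, grow again after a stable stretch); it is not this module's actual update_scale logic.

```python
class TinyDynamicScaler:
    """Simplified dynamic loss scaling, for illustration only."""

    def __init__(self, init_scale: float = 2.0 ** 15, backoff: float = 0.5, growth_interval: int = 2000):
        self.loss_scale: float = init_scale
        self.backoff: float = backoff
        self.growth_interval: int = growth_interval
        self.steps_since_overflow: int = 0

    def update_scale(self, overflow: bool) -> None:
        if overflow:
            # Inf/NaN gradients were detected: shrink the scale and restart the stability window.
            self.loss_scale *= self.backoff
            self.steps_since_overflow = 0
        else:
            self.steps_since_overflow += 1
            if self.steps_since_overflow % self.growth_interval == 0:
                # A long overflow-free stretch suggests it is safe to scale up again.
                self.loss_scale *= 2.0
```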
