Commit 5df1281

Merge pull request #97 from kozistr/feature/shampoo-optimizer
[Feature] Re-Implement Shampoo Optimizer w/ Grafting & Partitioner
2 parents 0567ae9 + e792181 · commit 5df1281

45 files changed: +1038 / -1148 lines

.pylintrc

Lines changed: 0 additions & 433 deletions
This file was deleted.

Makefile

Lines changed: 1 addition & 2 deletions
```diff
@@ -12,9 +12,8 @@ test:
 	python -m pytest -p no:pastebin -p no:nose -p no:doctest -sv -vv --cov=pytorch_optimizer --cov-report=xml ./tests
 
 check:
-	isort --check-only --profile black -l 119 pytorch_optimizer tests hubconf.py
 	black -S -l 119 --check pytorch_optimizer tests hubconf.py
-	pylint --fail-under=10.0 pytorch_optimizer
+	ruff pytorch_optimizer tests hubconf.py
 
 requirements:
 	python -m poetry export -f requirements.txt --output requirements.txt --without-hashes
```

docs/util_api.rst

Lines changed: 0 additions & 8 deletions
```diff
@@ -17,14 +17,6 @@ get_optimizer_parameters
 .. autoclass:: pytorch_optimizer.get_optimizer_parameters
     :members:
 
-.. _matrix_power:
-
-matrix_power
-------------
-
-.. autoclass:: pytorch_optimizer.matrix_power
-    :members:
-
 .. _normalize_gradient:
 
 normalize_gradient
```

poetry.lock

Lines changed: 99 additions & 266 deletions
Generated file; diff not rendered.

pyproject.toml

Lines changed: 43 additions & 2 deletions
```diff
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "pytorch_optimizer"
-version = "2.2.1"
+version = "2.3.0"
 description = "optimizer & lr scheduler implementations in PyTorch with clean-code, strict types. Also, including useful optimization ideas."
 license = "Apache-2.0"
 authors = ["kozistr <[email protected]>"]
@@ -42,7 +42,7 @@ torch = { version = "^1.10", source = "torch"}
 [tool.poetry.dev-dependencies]
 isort = "^5.11.4"
 black = "^22.12.0"
-pylint = "^2.15.9"
+ruff = "^0.0.237"
 pytest = "^7.2.0"
 pytest-cov = "^4.0.0"
@@ -51,9 +51,50 @@ name = "torch"
 url = "https://download.pytorch.org/whl/cpu"
 secondary = true
 
+[tool.ruff]
+select = ["A", "B", "C4", "E", "F", "G", "I", "N", "S", "T", "ISC", "W", "INP", "PIE", "T20", "RET", "SIM", "ARG"]
+ignore = []
+fixable = ["A", "B", "C", "D", "E", "F"]
+unfixable = ["F401"]
+exclude = [
+    ".eggs",
+    ".git",
+    ".mypy_cache",
+    ".ruff_cache",
+    ".github",
+    ".venv",
+    "__pypackages__",
+    "_build",
+    "build",
+    "dist",
+    "node_modules",
+    "venv",
+    "docs",
+    "assets",
+]
+line-length = 119
+dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
+target-version = "py39"
+
+[tool.ruff.per-file-ignores]
+"./hubconf.py" = ["INP001"]
+"./tests/test_utils.py" = ["S101"]
+"./tests/test_gradients.py" = ["S101"]
+"./tests/test_optimizers.py" = ["S101"]
+"./tests/test_optimizer_parameters.py" = ["S101"]
+"./tests/test_load_optimizers.py" = ["S101"]
+"./tests/test_load_lr_schedulers.py" = ["S101"]
+"./tests/test_lr_scheduler_parameters.py" = ["S101"]
+"./pytorch_optimizer/__init__.py" = ["F401"]
+"./pytorch_optimizer/lr_scheduler/__init__.py" = ["F401"]
+
+[tool.ruff.mccabe]
+max-complexity = 10
+
 [tool.coverage.run]
 omit = [
     "./pytorch_optimizer/optimizer/gsam.py",
+    "./pytorch_optimizer/optimizer/fp16.py",
 ]
 
 [build-system]
```
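
For context on the new `[tool.ruff.per-file-ignores]` table: nearly every entry silences one of two rule codes. S101 ("assert used") fires on the bare `assert` statements that pytest tests are built around, and F401 ("unused import") fires on re-export-style imports in `__init__.py` files. A minimal hypothetical sketch of both patterns (the module and function names are illustrative, not part of this commit):

```python
"""Hypothetical module (not from this commit) showing the two patterns whitelisted above."""

from typing import List  # noqa: F401  # re-export-style import; ruff's F401 ("unused import") would flag it


def test_addition() -> None:
    # pytest relies on bare `assert`, which ruff's S101 ("assert used") flags,
    # hence the ["S101"] entry for every tests/test_*.py module above.
    assert 1 + 1 == 2
```

The remaining entry, INP001 on `hubconf.py`, marks a file that sits outside a package (no neighboring `__init__.py`), which is intentional for a torch.hub entry point.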

pytorch_optimizer/__init__.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -1,4 +1,4 @@
-# pylint: disable=unused-import
+# ruff: noqa
 from typing import Dict, List
 
 from pytorch_optimizer.base.types import OPTIMIZER, SCHEDULER
@@ -44,7 +44,6 @@
     disable_running_stats,
     enable_running_stats,
     get_optimizer_parameters,
-    matrix_power,
     normalize_gradient,
     unit_norm,
 )
```

pytorch_optimizer/base/optimizer.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -2,7 +2,7 @@
 
 import torch
 
-from pytorch_optimizer.base.exception import NegativeLRError
+from pytorch_optimizer.base.exception import NegativeLRError, NegativeStepError
 from pytorch_optimizer.base.types import BETAS
 
 
@@ -90,7 +90,7 @@ def validate_reduction(reduction: str):
     @staticmethod
     def validate_update_frequency(update_frequency: int):
         if update_frequency < 1:
-            raise ValueError(f'[-] update_frequency {update_frequency} must be positive')
+            raise NegativeStepError(f'[-] update_frequency {update_frequency} must be positive')
 
     @staticmethod
     def validate_norm(norm: float):
```
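
A quick sketch of the behavioral change in `validate_update_frequency`: values below 1 now raise the library's own `NegativeStepError` instead of a bare `ValueError`. The snippet below mirrors the validator in isolation (the exception class here is a local stand-in for `pytorch_optimizer.base.exception.NegativeStepError`):

```python
class NegativeStepError(Exception):
    """Local stand-in for pytorch_optimizer.base.exception.NegativeStepError."""


def validate_update_frequency(update_frequency: int) -> None:
    # Mirrors BaseOptimizer.validate_update_frequency after this commit.
    if update_frequency < 1:
        raise NegativeStepError(f'[-] update_frequency {update_frequency} must be positive')


validate_update_frequency(2)  # passes silently

try:
    validate_update_frequency(0)
except NegativeStepError as exc:
    print(exc)  # [-] update_frequency 0 must be positive
```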

pytorch_optimizer/lr_scheduler/__init__.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,2 +1,2 @@
-# pylint: disable=unused-import
+# ruff: noqa
 from torch.optim.lr_scheduler import ConstantLR, CosineAnnealingLR, CosineAnnealingWarmRestarts, CyclicLR, OneCycleLR
```

pytorch_optimizer/lr_scheduler/chebyshev.py

Lines changed: 3 additions & 5 deletions
```diff
@@ -2,12 +2,12 @@
 
 
 def chebyshev_steps(small_m: float, big_m: float, num_epochs: int) -> np.ndarray:
-    """chebyshev_steps
+    r"""chebyshev_steps
 
     :param small_m: float. stands for 'm' notation.
     :param big_m: float. stands for 'M' notation.
     :param num_epochs: int. stands for 'T' notation.
-    :return: np.array. chebyshev_steps
+    :return: np.array. chebyshev_steps.
     """
 
     c, r = (big_m + small_m) / 2.0, (big_m - small_m) / 2.0
@@ -26,6 +26,4 @@ def chebyshev_perm(num_epochs: int) -> np.ndarray:
 def get_chebyshev_schedule(num_epochs: int) -> np.ndarray:
     steps: np.ndarray = chebyshev_steps(0.1, 1, num_epochs - 2)
     perm: np.ndarray = chebyshev_perm(num_epochs - 2)
-    chebyshev_schedule = steps[perm]
-
-    return chebyshev_schedule
+    return steps[perm]
```
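
The `get_chebyshev_schedule` cleanup relies on NumPy fancy indexing: indexing an array with an integer permutation array returns a new array in that order, so the temporary variable added nothing. A small standalone illustration (the values are made up, not real Chebyshev steps):

```python
import numpy as np

steps = np.array([0.5, 1.0, 2.0, 4.0])  # made-up step sizes
perm = np.array([2, 0, 3, 1])           # a permutation of the indices

# Fancy indexing builds the reordered array in a single expression,
# which is why the function can simply `return steps[perm]`.
print(steps[perm])  # [2.  0.5 4.  1. ]
```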

pytorch_optimizer/optimizer/adabelief.py

Lines changed: 11 additions & 11 deletions
```diff
@@ -53,15 +53,15 @@ def __init__(
 
         self.validate_parameters()
 
-        defaults: DEFAULTS = dict(
-            lr=lr,
-            betas=betas,
-            eps=eps,
-            weight_decay=weight_decay,
-            amsgrad=amsgrad,
-            adamd_debias_term=adamd_debias_term,
-            buffer=[[None, None, None] for _ in range(10)],
-        )
+        defaults: DEFAULTS = {
+            'lr': lr,
+            'betas': betas,
+            'eps': eps,
+            'weight_decay': weight_decay,
+            'amsgrad': amsgrad,
+            'adamd_debias_term': adamd_debias_term,
+            'buffer': [[None, None, None] for _ in range(10)],
+        }
         super().__init__(params, defaults)
 
     def validate_parameters(self):
@@ -71,7 +71,7 @@ def validate_parameters(self):
         self.validate_epsilon(self.eps)
 
     @property
-    def __name__(self) -> str:
+    def __str__(self) -> str:
         return 'AdaBelief'
 
     @torch.no_grad()
@@ -106,7 +106,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 grad = p.grad
                 if grad.is_sparse:
-                    raise NoSparseGradientError(self.__name__)
+                    raise NoSparseGradientError(self.__str__)
 
                 state = self.state[p]
                 if len(state) == 0:
```
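
Two notes on the AdaBelief changes above: the `defaults` mapping switches from the `dict(...)` keyword form to a literal dict with identical contents (the literal spelling is what ruff's C4 comprehension rules prefer), and the name hook used in error messages is renamed from `__name__` to `__str__`, still declared as a `@property`. Because it is a property, the error path reads it as an attribute rather than calling it. A minimal toy sketch of that pattern (this class is illustrative, not the real optimizer):

```python
class ToyOptimizer:
    """Toy class (not the real AdaBelief) showing the property-based name hook."""

    @property
    def __str__(self) -> str:
        return 'AdaBelief'


opt = ToyOptimizer()

# Read as an attribute, mirroring `raise NoSparseGradientError(self.__str__)`
# in the diff above; no call parentheses.
print(opt.__str__)  # AdaBelief
```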
