Skip to content

Commit fff34af

Browse files
authored
Merge pull request #106 from kozistr/deps/pytorch-version
[Deps] Support Pytorch 2.0
2 parents 06dce18 + 0f27033 commit fff34af

File tree

6 files changed

+120
-102
lines changed

6 files changed

+120
-102
lines changed

README.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ Install
3131

3232
$ pip3 install -U pytorch-optimizer
3333

34-
or
34+
If there's a version issue when installing the package, try with `--no-deps` option.
3535

3636
::
3737

poetry.lock

Lines changed: 92 additions & 78 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "pytorch_optimizer"
3-
version = "2.4.1"
3+
version = "2.4.2"
44
description = "optimizer & lr scheduler implementations in PyTorch with clean-code, strict types. Also, including useful optimization ideas."
55
license = "Apache-2.0"
66
authors = ["kozistr <[email protected]>"]
@@ -34,16 +34,16 @@ classifiers = [
3434
[tool.poetry.dependencies]
3535
python = "^3.7.2"
3636
numpy = [
37-
{ version = "=1.21.1", python = ">=3.7,<3.8"},
38-
{ version = "*", python = ">=3.8"},
37+
{ version = "=1.21.1", python = ">=3.7,<3.8" },
38+
{ version = "*", python = ">=3.8" },
3939
]
40-
torch = { version = "^1.10", source = "torch"}
40+
torch = { version = ">=1.10", source = "torch" }
4141

4242
[tool.poetry.dev-dependencies]
43-
isort = "^5.11.4"
44-
black = "^22.12.0"
45-
ruff = "^0.0.237"
46-
pytest = "^7.2.0"
43+
isort = "^5.11.5"
44+
black = "^23.1.0"
45+
ruff = "^0.0.244"
46+
pytest = "^7.2.1"
4747
pytest-cov = "^4.0.0"
4848

4949
[[tool.poetry.source]]
@@ -53,7 +53,7 @@ secondary = true
5353

5454
[tool.ruff]
5555
select = ["A", "B", "C4", "D", "E", "F", "G", "I", "N", "S", "T", "ISC", "ICN", "W", "INP", "PIE", "T20", "RET", "SIM", "TID", "ARG", "ERA", "RUF", "YTT", "PL"]
56-
ignore = ["D100", "D102", "D104", "D105", "D107", "D203", "D213", "PIE790", "PLR2004"]
56+
ignore = ["D100", "D102", "D104", "D105", "D107", "D203", "D213", "PIE790", "PLR0912", "PLR0913", "PLR0915", "PLR2004"]
5757
fixable = ["A", "B", "C", "D", "E", "F"]
5858
unfixable = ["F401"]
5959
exclude = [

pytorch_optimizer/optimizer/shampoo.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
119119

120120
pre_cond.add_(grad @ grad_t)
121121
if state['step'] % self.preconditioning_compute_steps == 0:
122-
inv_pre_cond = compute_power_svd(pre_cond, -1.0 / order)
122+
inv_pre_cond.copy_(compute_power_svd(pre_cond, -1.0 / order))
123123

124124
if dim_id == order - 1:
125125
grad = grad_t @ inv_pre_cond
@@ -151,8 +151,12 @@ class ScalableShampoo(Optimizer, BaseOptimizer):
151151
:param inverse_exponent_override: int. fixed exponent for pre-conditioner, if > 0.
152152
:param start_preconditioning_step: int.
153153
:param preconditioning_compute_steps: int. performance tuning params for controlling memory and compute
154-
requirements. How often to compute pre-conditioner.
155-
:param statistics_compute_steps: int. How often to compute statistics.
154+
requirements. How often to compute pre-conditioner. Ideally, 1 is the best. However, the current implementation
155+
doesn't work on the distributed environment (there are no statistics & pre-conditioners sync among replicas),
156+
compute on the GPU (not CPU) and the precision is fp32 (not fp64).
157+
Also, followed by the paper, `preconditioning_compute_steps` does not have a significant effect on the
158+
performance. So, If you have a problem with the speed, try to set this step bigger (e.g. 1000).
159+
:param statistics_compute_steps: int. How often to compute statistics. usually set to 1 (or 10).
156160
:param block_size: int. Block size for large layers (if > 0).
157161
Block size = 1 ==> Adagrad (Don't do this, extremely inefficient!)
158162
Block size should be as large as feasible under memory/time constraints.
@@ -166,8 +170,8 @@ class ScalableShampoo(Optimizer, BaseOptimizer):
166170
:param diagonal_eps: float. term added to the denominator to improve numerical stability.
167171
:param matrix_eps: float. term added to the denominator to improve numerical stability.
168172
:param use_svd: bool. use SVD instead of Schur-Newton method to calculate M^{-1/p}.
169-
Theoretically, Schur-Newton method is faster than SVD method to calculate M^{-1/p}.
170-
However, the inefficiency of the loop code, SVD is much faster than that.
173+
Theoretically, Schur-Newton method is faster than SVD method. However, the inefficiency of the loop code and
174+
proper svd kernel, SVD is much faster in some cases (usually in case of small models).
171175
see https://github.com/kozistr/pytorch_optimizer/pull/103
172176
"""
173177

requirements-dev.txt

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,26 @@
11
--extra-index-url https://download.pytorch.org/whl/cpu
22

33
attrs==22.2.0 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
4-
black==22.12.0 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
4+
black==23.1.0 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
55
click==8.1.3 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
66
colorama==0.4.6 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0" and sys_platform == "win32" or python_full_version >= "3.7.2" and python_full_version < "4.0.0" and platform_system == "Windows"
77
coverage[toml]==7.1.0 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
88
exceptiongroup==1.1.0 ; python_full_version >= "3.7.2" and python_version < "3.11"
99
importlib-metadata==6.0.0 ; python_full_version >= "3.7.2" and python_version < "3.8"
1010
iniconfig==2.0.0 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
11-
isort==5.11.4 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
12-
mypy-extensions==0.4.3 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
11+
isort==5.11.5 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
12+
mypy-extensions==1.0.0 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
1313
numpy==1.21.1 ; python_full_version >= "3.7.2" and python_version < "3.8"
14-
numpy==1.24.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
14+
numpy==1.24.2 ; python_version >= "3.8" and python_full_version < "4.0.0"
1515
packaging==23.0 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
1616
pathspec==0.11.0 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
17-
platformdirs==2.6.2 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
17+
platformdirs==3.0.0 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
1818
pluggy==1.0.0 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
1919
pytest-cov==4.0.0 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
2020
pytest==7.2.1 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
21-
ruff==0.0.237 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
22-
tomli==2.0.1 ; python_full_version >= "3.7.2" and python_full_version < "3.11.0a7"
21+
ruff==0.0.244 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
22+
tomli==2.0.1 ; python_full_version >= "3.7.2" and python_full_version <= "3.11.0a6"
2323
torch==1.13.1+cpu ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
2424
typed-ast==1.5.4 ; python_version < "3.8" and implementation_name == "cpython" and python_full_version >= "3.7.2"
2525
typing-extensions==4.4.0 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
26-
zipp==3.12.0 ; python_full_version >= "3.7.2" and python_version < "3.8"
26+
zipp==3.13.0 ; python_full_version >= "3.7.2" and python_version < "3.8"

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
--extra-index-url https://download.pytorch.org/whl/cpu
22

33
numpy==1.21.1 ; python_full_version >= "3.7.2" and python_version < "3.8"
4-
numpy==1.24.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
4+
numpy==1.24.2 ; python_version >= "3.8" and python_full_version < "4.0.0"
55
torch==1.13.1+cpu ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
66
typing-extensions==4.4.0 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"

0 commit comments

Comments
 (0)