Skip to content

Commit 53a2880

Browse files
authored
[Fix] type-hint (#404)
* fix: type-hint * docs: v3.6.2 changelog * build(deps): dev dependencies * fix: type hint * docs: v3.7.0 changelog * update: test_version_utils * fix: test_version_utils * docs: v3.7.0 changelog
1 parent 59522cc commit 53a2880

File tree

7 files changed

+117
-100
lines changed

7 files changed

+117
-100
lines changed

docs/changelogs/v3.6.2.md renamed to docs/changelogs/v3.7.0.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@
99
* Implement `EmoNavi`, `EmoFact`, and `EmoLynx` optimizers. (#393, #400)
1010
* [An emotion-driven optimizer that feels loss and navigates accordingly](https://github.com/muooon/EmoNavi)
1111

12+
### CI
13+
14+
* Enable CI for Python 3.8 ~ 3.13. (#402, #404)
15+
1216
### Fix
1317

1418
* Adjust the value of `eps` to the fixed value `1e-15` when adding to `exp_avg_sq`. (#397, #398)
19+
* built-in type-hint in `Kron` optimizer. (#404)

poetry.lock

Lines changed: 90 additions & 90 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,12 @@ black = [
6565
{ version = "<25", python = "<3.9" },
6666
{ version = "^25", python = ">=3.9" },
6767
]
68-
ruff = "*"
69-
pytest = "*"
70-
pytest-cov = "*"
68+
ruff = "^0.12"
69+
pytest = "^8"
70+
pytest-cov = [
71+
{ version = "<6", python = "<3.9" },
72+
{ version = "^6", python = ">=3.9" },
73+
]
7174

7275
[[tool.poetry.source]]
7376
name = "torch"
@@ -139,5 +142,5 @@ omit = [
139142
]
140143

141144
[build-system]
142-
requires = ["poetry-core>=1.4.0"]
145+
requires = ["poetry-core>=2.0.0"]
143146
build-backend = "poetry.core.masonry.api"

pytorch_optimizer/optimizer/psgd.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def initialize_q_expressions(
222222
dtype: torch.dtype = dtype if dtype is not None else t.dtype
223223
shape = t.shape
224224
if len(shape) == 0:
225-
qs: list[torch.Tensor] = [scale * torch.ones_like(t, dtype=dtype)]
225+
qs: List[torch.Tensor] = [scale * torch.ones_like(t, dtype=dtype)]
226226
expressions_a: str = ',->'
227227
expression_gr: List[str] = [',->']
228228
expression_r: str = ',,->'
@@ -370,6 +370,6 @@ def update_precondition(
370370
q.sub_(tmp)
371371

372372

373-
def get_precondition_grad(qs: list[torch.Tensor], expressions: list[str], g: torch.Tensor) -> torch.Tensor:
373+
def get_precondition_grad(qs: List[torch.Tensor], expressions: List[str], g: torch.Tensor) -> torch.Tensor:
374374
r"""Precondition gradient G with pre-conditioner Q."""
375375
return torch.einsum(expressions[-1], *[x.conj() for x in qs], *qs, g)

pytorch_optimizer/optimizer/splus.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Tuple
2+
13
import torch
24

35
from pytorch_optimizer.base.exception import NoComplexParameterError, NoSparseGradientError
@@ -122,7 +124,7 @@ def init_group(self, group: GROUP, **kwargs) -> None:
122124
]
123125

124126
@staticmethod
125-
def get_scaled_lr(shape: tuple[int, int], lr: float, nonstandard_constant: float, max_dim: int = 10000) -> float:
127+
def get_scaled_lr(shape: Tuple[int, int], lr: float, nonstandard_constant: float, max_dim: int = 10000) -> float:
126128
scale: float = (
127129
nonstandard_constant
128130
if len(shape) != 2 or shape[0] > max_dim or shape[1] > max_dim

requirements-dev.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ black==24.8.0 ; python_version == "3.8"
44
black==25.1.0 ; python_version >= "3.9"
55
click==8.1.8 ; python_version >= "3.8"
66
colorama==0.4.6 ; python_version >= "3.8" and (sys_platform == "win32" or platform_system == "Windows")
7-
coverage[toml]==7.10.0 ; python_version >= "3.9"
7+
coverage[toml]==7.10.1 ; python_version >= "3.9"
88
coverage[toml]==7.6.1 ; python_version == "3.8"
99
exceptiongroup==1.3.0 ; python_version < "3.11" and python_version >= "3.8"
1010
filelock==3.16.1 ; python_version == "3.8"
@@ -29,7 +29,8 @@ platformdirs==4.3.6 ; python_version == "3.8"
2929
platformdirs==4.3.8 ; python_version >= "3.9"
3030
pluggy==1.5.0 ; python_version == "3.8"
3131
pluggy==1.6.0 ; python_version >= "3.9"
32-
pytest-cov==5.0.0 ; python_version >= "3.8"
32+
pytest-cov==5.0.0 ; python_version == "3.8"
33+
pytest-cov==6.2.1 ; python_version >= "3.9"
3334
pytest==8.3.5 ; python_version >= "3.8"
3435
ruff==0.12.5 ; python_version >= "3.8"
3536
setuptools==80.9.0 ; python_version >= "3.12"

tests/test_utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import sys
12
from typing import List
23

34
import numpy as np
@@ -250,7 +251,12 @@ def test_version_utils():
250251
with pytest.raises(ValueError):
251252
parse_pytorch_version('a.s.d.f')
252253

253-
assert parse_pytorch_version(torch.__version__) == [2, 7, 1]
254+
python_version = sys.version_info
255+
256+
if python_version.minor < 9:
257+
assert parse_pytorch_version(torch.__version__) == [2, 4, 1]
258+
else:
259+
assert parse_pytorch_version(torch.__version__) == [2, 7, 1]
254260

255261
assert compare_versions('2.7.0', '2.4.0') >= 0
256262

0 commit comments

Comments
 (0)