Skip to content

Commit 84b926c

Browse files
authored
Merge pull request #374 from kozistr/release/v3.5.1
[Release] v3.5.1
2 parents 0ad115c + 29669b0 commit 84b926c

File tree

7 files changed

+71
-88
lines changed

7 files changed

+71
-88
lines changed

poetry.lock

Lines changed: 47 additions & 65 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "pytorch_optimizer"
3-
version = "3.5.0"
3+
version = "3.5.1"
44
description = "optimizer & lr scheduler & objective function collections in PyTorch"
55
license = "Apache-2.0"
66
authors = ["kozistr <[email protected]>"]

pytorch_optimizer/optimizer/spam.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ def step(self, current_step: int) -> None:
3131
3232
:param current_step: int. Current step index.
3333
"""
34-
self.cosine_stepper.step(current_step)
34+
self.cosine_stepper.last_epoch = current_step
35+
self.cosine_stepper.step()
3536

3637
def get_death_rate(self, current_step: int) -> float:
3738
r"""Get the updated rate (death_rate) at the given step.
@@ -266,9 +267,9 @@ class StableSPAM(BaseOptimizer):
266267
:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
267268
:param lr: float. learning rate.
268269
:param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
269-
:param gamma1: float.
270-
:param gamma2: float.
271-
:param theta: float.
270+
:param gamma1: float. gamma1 parameter.
271+
:param gamma2: float. gamma2 parameter.
272+
:param theta: float. theta parameter.
272273
:param t_max: Optional[int]. total number of steps.
273274
:param eta_min: float. eta_min of CosineDecay.
274275
:param weight_decay: float. weight decay (L2 penalty).

requirements-dev.txt

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ jinja2==3.1.6 ; python_version >= "3.8"
1818
markupsafe==2.1.5 ; python_version == "3.8"
1919
markupsafe==3.0.2 ; python_version >= "3.9"
2020
mpmath==1.3.0 ; python_version >= "3.8"
21-
mypy-extensions==1.0.0 ; python_version >= "3.8"
21+
mypy-extensions==1.1.0 ; python_version >= "3.8"
2222
networkx==3.1 ; python_version == "3.8"
2323
networkx==3.2.1 ; python_version >= "3.9"
2424
numpy==1.24.4 ; python_version == "3.8"
@@ -30,11 +30,10 @@ platformdirs==4.3.7 ; python_version >= "3.9"
3030
pluggy==1.5.0 ; python_version >= "3.8"
3131
pytest-cov==5.0.0 ; python_version >= "3.8"
3232
pytest==8.3.5 ; python_version >= "3.8"
33-
ruff==0.11.6 ; python_version >= "3.8"
34-
setuptools==79.0.0 ; python_version >= "3.12"
35-
sympy==1.13.1 ; python_version >= "3.9"
36-
sympy==1.13.3 ; python_version == "3.8"
33+
ruff==0.11.7 ; python_version >= "3.8"
34+
setuptools==79.0.1 ; python_version >= "3.12"
35+
sympy==1.13.3 ; python_version >= "3.8"
3736
tomli==2.2.1 ; python_full_version <= "3.11.0a6" and python_version >= "3.8"
3837
torch==2.4.1+cpu ; python_version == "3.8"
39-
torch==2.6.0+cpu ; python_version >= "3.9"
38+
torch==2.7.0+cpu ; python_version >= "3.9"
4039
typing-extensions==4.13.2 ; python_version >= "3.8"

requirements.txt

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,8 @@ networkx==3.1 ; python_version == "3.8"
1212
networkx==3.2.1 ; python_version >= "3.9"
1313
numpy==1.24.4 ; python_version == "3.8"
1414
numpy==2.0.2 ; python_version >= "3.9"
15-
setuptools==79.0.0 ; python_version >= "3.12"
16-
sympy==1.13.1 ; python_version >= "3.9"
17-
sympy==1.13.3 ; python_version == "3.8"
15+
setuptools==79.0.1 ; python_version >= "3.12"
16+
sympy==1.13.3 ; python_version >= "3.8"
1817
torch==2.4.1+cpu ; python_version == "3.8"
19-
torch==2.6.0+cpu ; python_version >= "3.9"
18+
torch==2.7.0+cpu ; python_version >= "3.9"
2019
typing-extensions==4.13.2 ; python_version >= "3.8"

tests/test_lr_schedulers.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -200,17 +200,19 @@ def test_get_chebyshev_lr():
200200
optimizer.step()
201201

202202
lr_scheduler = get_chebyshev_schedule(optimizer, num_epochs=16, is_warmup=True)
203-
lr_scheduler.step(0)
203+
lr_scheduler.last_epoch = 0
204+
lr_scheduler.step()
204205

205206
np.testing.assert_almost_equal(lr_scheduler.get_last_lr(), 1e-3)
206207

207208
optimizer = AdamW(Example().parameters())
208209
optimizer.step()
209210

210211
lr_scheduler = get_chebyshev_schedule(optimizer, num_epochs=16, is_warmup=False)
212+
lr_scheduler.last_epoch = 0
211213

212-
for i, expected_lr in enumerate(recipes, start=1):
213-
lr_scheduler.step(i)
214+
for expected_lr in recipes:
215+
lr_scheduler.step()
214216
np.testing.assert_almost_equal(lr_scheduler.get_last_lr(), expected_lr)
215217

216218

@@ -311,10 +313,10 @@ def test_wsd_lr_scheduler():
311313

312314
lr_scheduler = get_wsd_schedule(optimizer, 2, 2, 3, min_lr_ratio=0.1)
313315

314-
expected_lrs = [0.0, 0.0005, 0.001, 0.001, 0.001, 0.000775, 0.000325, 0.0001, 0.0001, 0.0001]
316+
expected_lrs = [0.0005, 0.001, 0.001, 0.001, 0.000775, 0.000325, 0.0001, 0.0001, 0.0001]
315317

316-
for step, expected_lr in enumerate(expected_lrs):
317-
lr_scheduler.step(step)
318+
for expected_lr in expected_lrs:
319+
lr_scheduler.step()
318320
np.testing.assert_almost_equal(expected_lr, lr_scheduler.get_last_lr(), 6)
319321

320322

tests/test_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,9 +247,9 @@ def test_version_utils():
247247
with pytest.raises(ValueError):
248248
parse_pytorch_version('a.s.d.f')
249249

250-
assert parse_pytorch_version(torch.__version__) == [2, 6, 0]
250+
assert parse_pytorch_version(torch.__version__) == [2, 7, 0]
251251

252-
assert compare_versions('2.6.0', '2.4.0') >= 0
252+
assert compare_versions('2.7.0', '2.4.0') >= 0
253253

254254

255255
def test_cpu_offload_optimizer():

0 commit comments

Comments
 (0)