Skip to content

Commit c33e53e

Browse files
authored
Merge pull request #170 from kozistr/fix/ranger21-wd
[Fix] weight decay in Ranger21 optimizer
2 parents 5ee6ed6 + 2f7b243 commit c33e53e

File tree

7 files changed

+34
-25
lines changed

7 files changed

+34
-25
lines changed

docs/changelogs/v2.9.1.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
## Change Log
2+
3+
### Fix
4+
5+
* fix weight decay in Ranger21 (#170)
6+
7+
## Diff
8+
9+
[2.9.0...2.9.1](https://github.com/kozistr/pytorch_optimizer/compare/v2.9.0...v2.9.1)

poetry.lock

Lines changed: 16 additions & 16 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[tool.poetry]
22
name = "pytorch_optimizer"
3-
version = "2.9.0"
4-
description = "optimizer & lr scheduler implementations in PyTorch with clean-code, strict types. Also, including useful optimization ideas."
3+
version = "2.9.1"
4+
description = "optimizer & lr scheduler collections in PyTorch"
55
license = "Apache-2.0"
66
authors = ["kozistr <[email protected]>"]
77
maintainers = ["kozistr <[email protected]>"]

pytorch_optimizer/optimizer/agc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
def agc(p: torch.Tensor, grad: torch.Tensor, agc_eps: float, agc_clip_val: float, eps: float = 1e-6) -> torch.Tensor:
77
r"""Clip gradient values in excess of the unit wise norm.
88
9-
:param p: Parameter. parameter.
9+
:param p: torch.Tensor. parameter.
1010
:param grad: torch.Tensor, gradient.
1111
:param agc_eps: float. agc epsilon to clip the norm of parameter.
1212
:param agc_clip_val: float. norm clip.

pytorch_optimizer/optimizer/ranger21.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
251251
self.apply_weight_decay(
252252
p=p,
253253
grad=None,
254-
lr=group['lr'],
254+
lr=lr,
255255
weight_decay=group['weight_decay'],
256256
weight_decouple=group['weight_decouple'],
257257
fixed_decay=group['fixed_decay'],

requirements-dev.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,15 @@ numpy==1.21.1 ; python_full_version >= "3.7.2" and python_version < "3.8"
1919
numpy==1.24.3 ; python_version >= "3.8" and python_full_version < "4.0.0"
2020
packaging==23.1 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
2121
pathspec==0.11.1 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
22-
platformdirs==3.5.0 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
22+
platformdirs==3.5.1 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
2323
pluggy==1.0.0 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
2424
pytest-cov==4.0.0 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
2525
pytest==7.3.1 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
2626
ruff==0.0.264 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
27-
sympy==1.11.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
27+
sympy==1.12 ; python_version >= "3.8" and python_full_version < "4.0.0"
2828
tomli==2.0.1 ; python_full_version >= "3.7.2" and python_full_version <= "3.11.0a6"
2929
torch==1.13.1+cpu ; python_full_version >= "3.7.2" and python_version < "3.8"
30-
torch==2.0.0+cpu ; python_version >= "3.8" and python_full_version < "4.0.0"
30+
torch==2.0.1+cpu ; python_version >= "3.8" and python_full_version < "4.0.0"
3131
typed-ast==1.5.4 ; python_version < "3.8" and implementation_name == "cpython" and python_full_version >= "3.7.2"
3232
typing-extensions==4.5.0 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
3333
zipp==3.15.0 ; python_full_version >= "3.7.2" and python_version < "3.8"

requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ mpmath==1.3.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
77
networkx==3.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
88
numpy==1.21.1 ; python_full_version >= "3.7.2" and python_version < "3.8"
99
numpy==1.24.3 ; python_version >= "3.8" and python_full_version < "4.0.0"
10-
sympy==1.11.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
10+
sympy==1.12 ; python_version >= "3.8" and python_full_version < "4.0.0"
1111
torch==1.13.1+cpu ; python_full_version >= "3.7.2" and python_version < "3.8"
12-
torch==2.0.0+cpu ; python_version >= "3.8" and python_full_version < "4.0.0"
12+
torch==2.0.1+cpu ; python_version >= "3.8" and python_full_version < "4.0.0"
1313
typing-extensions==4.5.0 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"

0 commit comments

Comments
 (0)