Skip to content

Commit 77098e9

Browse files
authored
[Release] v3.6.1 (#392)
* build(version): v3.6.1 * build(deps): update packages * style: PLC0415 * style: fix RUF005 * update: beta range to `[0, 1)` * docs: v3.6.1 changelog * update: default beta3 range to `[0, 1]` * update: validate_betas * update: validate_betas * update: OPTIMIZERS_IGNORE * update: test_version_utils
1 parent 2a4423d commit 77098e9

File tree

11 files changed

+193
-131
lines changed

11 files changed

+193
-131
lines changed

docs/changelogs/v3.6.1.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
## Change Log
22

3-
## Feature
3+
### Feature
44

55
* Implement more cooldown types for WSD learning rate scheduler. (#382, #386)
66
* Implement `AdamWSN` optimizer. (#387, #389)
77
* [Lean and Mean Adaptive Optimization via Subset-Norm and Subspace-Momentum with Convergence Guarantees](https://arxiv.org/abs/2411.07120)
88
* Implement `AdamC` optimizer. (#388, #390)
99
* [Why Gradients Rapidly Increase Near the End of Training](https://arxiv.org/abs/2506.02285)
1010

11+
### Update
12+
13+
* Change the default range of the `beta` parameter from `[0, 1]` to `[0, 1)`. (#392)
14+
1115
### Fix
1216

1317
* Fix to use `momentum buffer` instead of the gradient to calculate LMO. (#385)

examples/visualize_optimizers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
filterwarnings('ignore', category=UserWarning)
1818

19-
OPTIMIZERS_IGNORE = ('lomo', 'adalomo', 'demo', 'a2grad', 'muon', 'alice')
19+
OPTIMIZERS_IGNORE = ('lomo', 'adalomo', 'demo', 'a2grad', 'muon', 'alice', 'adamc', 'adamwsn')
2020
OPTIMIZERS_MODEL_INPUT_NEEDED = ('lomo', 'adalomo', 'adammini')
2121
OPTIMIZERS_GRAPH_NEEDED = ('adahessian', 'sophiah')
2222
OPTIMIZERS_CLOSURE_NEEDED = ('alig', 'bsam')

poetry.lock

Lines changed: 167 additions & 112 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "pytorch_optimizer"
3-
version = "3.6.0"
3+
version = "3.6.1"
44
description = "optimizer & lr scheduler & objective function collections in PyTorch"
55
license = "Apache-2.0"
66
authors = ["kozistr <[email protected]>"]
@@ -102,6 +102,7 @@ ignore = [
102102
"D100", "D102", "D104", "D105", "D107", "D203", "D213", "D413",
103103
"PLR0912", "PLR0913", "PLR0915", "PLR2004", "PLW2901",
104104
"Q003", "ARG002",
105+
"RUF028",
105106
]
106107
fixable = ["ALL"]
107108
unfixable = ["F401"]

pytorch_optimizer/base/optimizer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -358,17 +358,17 @@ def validate_mod(x: int, y: int) -> None:
358358
if x % y != 0:
359359
raise ValueError(f'[-] {x} must be divisible by {y}')
360360

361-
def validate_betas(self, betas: BETAS) -> None:
361+
def validate_betas(self, betas: BETAS, beta_range_type: str = '[)', beta3_range_type: str = '[]') -> None:
362362
if betas[0] is not None:
363-
self.validate_range(betas[0], 'beta1', 0.0, 1.0, range_type='[]')
363+
self.validate_range(betas[0], 'beta1', 0.0, 1.0, range_type=beta_range_type)
364364

365-
self.validate_range(betas[1], 'beta2', 0.0, 1.0, range_type='[]')
365+
self.validate_range(betas[1], 'beta2', 0.0, 1.0, range_type=beta_range_type)
366366

367367
if len(betas) < 3:
368368
return
369369

370370
if betas[2] is not None:
371-
self.validate_range(betas[2], 'beta3', 0.0, 1.0, range_type='[]')
371+
self.validate_range(betas[2], 'beta3', 0.0, 1.0, range_type=beta3_range_type)
372372

373373
def validate_nus(self, nus: Union[float, Tuple[float, float]]) -> None:
374374
if isinstance(nus, float):

pytorch_optimizer/loss/bi_tempered.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def compute_normalization_binary_search(activations: torch.Tensor, t: float, num
5050
activations.dtype
5151
)
5252

53-
shape_partition: Tuple[int, ...] = activations.shape[:-1] + (1,)
53+
shape_partition: Tuple[int, ...] = (*activations.shape[:-1], 1)
5454

5555
lower = torch.zeros(shape_partition, dtype=activations.dtype, device=activations.device)
5656
upper = -log_t(1.0 / effective_dim, t) * torch.ones_like(lower)

pytorch_optimizer/optimizer/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@
112112

113113
def load_bnb_optimizer(optimizer: str) -> OPTIMIZER: # pragma: no cover # noqa: PLR0911
114114
r"""Load bnb optimizer instance."""
115-
from bitsandbytes import optim
115+
from bitsandbytes import optim # noqa: PLC0415
116116

117117
if 'sgd8bit' in optimizer:
118118
return optim.SGD8bit
@@ -168,7 +168,7 @@ def load_bnb_optimizer(optimizer: str) -> OPTIMIZER: # pragma: no cover # noqa
168168

169169
def load_q_galore_optimizer(optimizer: str) -> OPTIMIZER: # pragma: no cover
170170
r"""Load Q-GaLore optimizer instance."""
171-
import q_galore_torch
171+
import q_galore_torch # noqa: PLC0415
172172

173173
if 'adamw8bit' in optimizer:
174174
return q_galore_torch.QGaLoreAdamW8bit
@@ -178,7 +178,7 @@ def load_q_galore_optimizer(optimizer: str) -> OPTIMIZER: # pragma: no cover
178178

179179
def load_ao_optimizer(optimizer: str) -> OPTIMIZER: # pragma: no cover
180180
r"""Load TorchAO optimizer instance."""
181-
from torchao.prototype import low_bit_optim
181+
from torchao.prototype import low_bit_optim # noqa: PLC0415
182182

183183
if 'adamw8bit' in optimizer:
184184
return low_bit_optim.AdamW8bit

pytorch_optimizer/optimizer/pnm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def __init__(
3333
**kwargs,
3434
):
3535
self.validate_learning_rate(lr)
36-
self.validate_betas(betas)
36+
self.validate_betas(betas, beta_range_type='[]')
3737
self.validate_non_negative(weight_decay, 'weight_decay')
3838
self.validate_non_negative(eps, 'eps')
3939

requirements-dev.txt

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ black==25.1.0 ; python_version >= "3.9"
55
click==8.1.8 ; python_version >= "3.8"
66
colorama==0.4.6 ; python_version >= "3.8" and (sys_platform == "win32" or platform_system == "Windows")
77
coverage[toml]==7.6.1 ; python_version == "3.8"
8-
coverage[toml]==7.8.2 ; python_version >= "3.9"
8+
coverage[toml]==7.9.2 ; python_version >= "3.9"
99
exceptiongroup==1.3.0 ; python_version < "3.11" and python_version >= "3.8"
1010
filelock==3.16.1 ; python_version == "3.8"
1111
filelock==3.18.0 ; python_version >= "3.9"
@@ -31,11 +31,12 @@ pluggy==1.5.0 ; python_version == "3.8"
3131
pluggy==1.6.0 ; python_version >= "3.9"
3232
pytest-cov==5.0.0 ; python_version >= "3.8"
3333
pytest==8.3.5 ; python_version >= "3.8"
34-
ruff==0.11.12 ; python_version >= "3.8"
34+
ruff==0.12.2 ; python_version >= "3.8"
3535
setuptools==80.9.0 ; python_version >= "3.12"
3636
sympy==1.13.3 ; python_version == "3.8"
3737
sympy==1.14.0 ; python_version >= "3.9"
3838
tomli==2.2.1 ; python_full_version <= "3.11.0a6" and python_version >= "3.8"
3939
torch==2.4.1+cpu ; python_version == "3.8"
40-
torch==2.7.0+cpu ; python_version >= "3.9"
41-
typing-extensions==4.13.2 ; python_version >= "3.8"
40+
torch==2.7.1+cpu ; python_version >= "3.9"
41+
typing-extensions==4.13.2 ; python_version == "3.8"
42+
typing-extensions==4.14.1 ; python_version >= "3.9"

requirements.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,6 @@ setuptools==80.9.0 ; python_version >= "3.12"
1616
sympy==1.13.3 ; python_version == "3.8"
1717
sympy==1.14.0 ; python_version >= "3.9"
1818
torch==2.4.1+cpu ; python_version == "3.8"
19-
torch==2.7.0+cpu ; python_version >= "3.9"
20-
typing-extensions==4.13.2 ; python_version >= "3.8"
19+
torch==2.7.1+cpu ; python_version >= "3.9"
20+
typing-extensions==4.13.2 ; python_version == "3.8"
21+
typing-extensions==4.14.1 ; python_version >= "3.9"

0 commit comments

Comments
 (0)