Skip to content

Commit 5fdc6af

Browse files
committed
Update to geoopt 0.5.1; add support for Python 3.12+
1 parent 65112cd commit 5fdc6af

File tree

3 files changed

+27
-29
lines changed

3 files changed

+27
-29
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ jobs:
1212
strategy:
1313
fail-fast: false # don’t stop the matrix if one Python version fails
1414
matrix:
15-
python-version: ["3.10", "3.11"] # jaxtyping requires >= 3.10; scipy requires < 3.12
15+
python-version: ["3.10", "3.11", "3.12", "3.13"] # jaxtyping requires >= 3.10
1616

1717
steps:
1818
# Setup and installation

manify/optimizers/_adan.py

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@
2222

2323
import torch
2424

25-
if TYPE_CHECKING:
26-
from beartype.typing import List
2725
from torch import Tensor
2826
from torch.optim.optimizer import Optimizer
2927

@@ -242,12 +240,12 @@ def step(self, closure=None):
242240

243241

244242
def _single_tensor_adan(
245-
params: List[Tensor],
246-
grads: List[Tensor],
247-
exp_avgs: List[Tensor],
248-
exp_avg_sqs: List[Tensor],
249-
exp_avg_diffs: List[Tensor],
250-
neg_pre_grads: List[Tensor],
243+
params: list[Tensor],
244+
grads: list[Tensor],
245+
exp_avgs: list[Tensor],
246+
exp_avg_sqs: list[Tensor],
247+
exp_avg_diffs: list[Tensor],
248+
neg_pre_grads: list[Tensor],
251249
*,
252250
beta1: float,
253251
beta2: float,
@@ -297,12 +295,12 @@ def _single_tensor_adan(
297295

298296

299297
def _multi_tensor_adan(
300-
params: List[Tensor],
301-
grads: List[Tensor],
302-
exp_avgs: List[Tensor],
303-
exp_avg_sqs: List[Tensor],
304-
exp_avg_diffs: List[Tensor],
305-
neg_pre_grads: List[Tensor],
298+
params: list[Tensor],
299+
grads: list[Tensor],
300+
exp_avgs: list[Tensor],
301+
exp_avg_sqs: list[Tensor],
302+
exp_avg_diffs: list[Tensor],
303+
neg_pre_grads: list[Tensor],
306304
*,
307305
beta1: float,
308306
beta2: float,
@@ -356,12 +354,12 @@ def _multi_tensor_adan(
356354

357355

358356
def _fused_adan_multi_tensor(
359-
params: List[Tensor],
360-
grads: List[Tensor],
361-
exp_avgs: List[Tensor],
362-
exp_avg_sqs: List[Tensor],
363-
exp_avg_diffs: List[Tensor],
364-
neg_pre_grads: List[Tensor],
357+
params: list[Tensor],
358+
grads: list[Tensor],
359+
exp_avgs: list[Tensor],
360+
exp_avg_sqs: list[Tensor],
361+
exp_avg_diffs: list[Tensor],
362+
neg_pre_grads: list[Tensor],
365363
*,
366364
beta1: float,
367365
beta2: float,
@@ -400,12 +398,12 @@ def _fused_adan_multi_tensor(
400398

401399

402400
def _fused_adan_single_tensor(
403-
params: List[Tensor],
404-
grads: List[Tensor],
405-
exp_avgs: List[Tensor],
406-
exp_avg_sqs: List[Tensor],
407-
exp_avg_diffs: List[Tensor],
408-
neg_pre_grads: List[Tensor],
401+
params: list[Tensor],
402+
grads: list[Tensor],
403+
exp_avgs: list[Tensor],
404+
exp_avg_sqs: list[Tensor],
405+
exp_avg_diffs: list[Tensor],
406+
neg_pre_grads: list[Tensor],
409407
*,
410408
beta1: float,
411409
beta2: float,

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@ authors = [
1212
]
1313
dependencies = [
1414
"torch",
15-
"geoopt",
15+
"geoopt>=0.5.1",
1616
"numpy",
1717
"tqdm",
1818
"cvxpy",
1919
"scikit-learn==1.5.1",
20-
"scipy<=1.9.3", # Ensure scalar_search_wolfe2 import works correctly
20+
"geoopt",
2121
"datasets"
2222
]
2323

0 commit comments

Comments
 (0)