Skip to content

Commit 01dc147

Browse files
committed
update: test_betas
1 parent 6fb3a4f commit 01dc147

File tree

1 file changed

+28
-4
lines changed

1 file changed

+28
-4
lines changed

tests/test_optimizer_parameters.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from pytorch_optimizer import load_optimizers
66

7-
VALID_OPTIMIZER_NAMES: List[str] = [
7+
OPTIMIZER_NAMES: List[str] = [
88
'adamp',
99
'sgdp',
1010
'madgrad',
@@ -19,23 +19,47 @@
1919
'lamb',
2020
]
2121

22+
# Optimizers whose constructors accept (and validate) a `betas` pair.
BETA_OPTIMIZER_NAMES: List[str] = [
    'adabelief',
    'adabound',
    'adahessian',
    'adamp',  # fixed typo: was 'admap', which matches no registered optimizer
    'diffgrad',
    'diffrgrad',
    'lamb',
    'radam',
    'ranger',
    'ranger21',
]
34+
2235

23-
@pytest.mark.parametrize('optimizer_names', OPTIMIZER_NAMES)
def test_learning_rate(optimizer_names):
    """A negative learning rate must raise ValueError at construction time."""
    # Load outside the raises-block so an unexpected failure inside
    # load_optimizers cannot masquerade as the expected ValueError.
    optimizer = load_optimizers(optimizer_names)
    with pytest.raises(ValueError):
        optimizer(None, lr=-1e-2)
2841

2942

30-
@pytest.mark.parametrize('optimizer_names', OPTIMIZER_NAMES)
def test_epsilon(optimizer_names):
    """A negative epsilon must raise ValueError at construction time."""
    # Load outside the raises-block so an unexpected failure inside
    # load_optimizers cannot masquerade as the expected ValueError.
    optimizer = load_optimizers(optimizer_names)
    with pytest.raises(ValueError):
        optimizer(None, eps=-1e-6)
3548

3649

37-
@pytest.mark.parametrize('optimizer_names', OPTIMIZER_NAMES)
def test_weight_decay(optimizer_names):
    """A negative weight decay must raise ValueError at construction time."""
    # Load outside the raises-block so an unexpected failure inside
    # load_optimizers cannot masquerade as the expected ValueError.
    optimizer = load_optimizers(optimizer_names)
    with pytest.raises(ValueError):
        optimizer(None, weight_decay=-1e-3)
55+
56+
57+
@pytest.mark.parametrize('optimizer_names', BETA_OPTIMIZER_NAMES)
def test_betas(optimizer_names):
    """A negative value in either position of `betas` must raise ValueError."""
    # Load once, outside the raises-blocks, so a loader failure cannot
    # masquerade as the expected ValueError from the constructor.
    optimizer = load_optimizers(optimizer_names)

    with pytest.raises(ValueError):
        optimizer(None, betas=(-0.1, 0.1))

    with pytest.raises(ValueError):
        optimizer(None, betas=(0.1, -0.1))

0 commit comments

Comments
 (0)