Skip to content

Commit f327405

Browse files
committed
update: split tests
1 parent 7d707b8 commit f327405

File tree

2 files changed

+41
-21
lines changed

2 files changed

+41
-21
lines changed

tests/test_load_optimizers.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -38,24 +38,3 @@ def test_load_optimizers_valid(valid_optimizer_names):
3838
def test_load_optimizers_invalid(invalid_optimizer_names):
3939
with pytest.raises(NotImplementedError):
4040
load_optimizers(invalid_optimizer_names)
41-
42-
43-
@pytest.mark.parametrize('optimizer_names', VALID_OPTIMIZER_NAMES)
44-
def test_learning_rate(optimizer_names):
45-
with pytest.raises(ValueError):
46-
optimizer = load_optimizers(optimizer_names)
47-
optimizer(None, lr=-1e-2)
48-
49-
50-
@pytest.mark.parametrize('optimizer_names', VALID_OPTIMIZER_NAMES)
51-
def test_epsilon(optimizer_names):
52-
with pytest.raises(ValueError):
53-
optimizer = load_optimizers(optimizer_names)
54-
optimizer(None, eps=-1e-6)
55-
56-
57-
@pytest.mark.parametrize('optimizer_names', VALID_OPTIMIZER_NAMES)
58-
def test_weight_decay(optimizer_names):
59-
with pytest.raises(ValueError):
60-
optimizer = load_optimizers(optimizer_names)
61-
optimizer(None, weight_decay=-1e-3)

tests/test_optimizer_parameters.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from typing import List
2+
3+
import pytest
4+
5+
from pytorch_optimizer import load_optimizers
6+
7+
VALID_OPTIMIZER_NAMES: List[str] = [
8+
'adamp',
9+
'sgdp',
10+
'madgrad',
11+
'ranger',
12+
'ranger21',
13+
'radam',
14+
'adabound',
15+
'adahessian',
16+
'adabelief',
17+
'diffgrad',
18+
'diffrgrad',
19+
'lamb',
20+
]
21+
22+
23+
@pytest.mark.parametrize('optimizer_names', VALID_OPTIMIZER_NAMES)
24+
def test_learning_rate(optimizer_names):
25+
with pytest.raises(ValueError):
26+
optimizer = load_optimizers(optimizer_names)
27+
optimizer(None, lr=-1e-2)
28+
29+
30+
@pytest.mark.parametrize('optimizer_names', VALID_OPTIMIZER_NAMES)
31+
def test_epsilon(optimizer_names):
32+
with pytest.raises(ValueError):
33+
optimizer = load_optimizers(optimizer_names)
34+
optimizer(None, eps=-1e-6)
35+
36+
37+
@pytest.mark.parametrize('optimizer_names', VALID_OPTIMIZER_NAMES)
38+
def test_weight_decay(optimizer_names):
39+
with pytest.raises(ValueError):
40+
optimizer = load_optimizers(optimizer_names)
41+
optimizer(None, weight_decay=-1e-3)

0 commit comments

Comments
 (0)