Skip to content

Commit 9308729

Browse files
committed
update: tests
1 parent 1d45233 commit 9308729

File tree

1 file changed

+50
-40
lines changed

1 file changed

+50
-40
lines changed

tests/test_optimizer_parameters.py

Lines changed: 50 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -40,88 +40,97 @@
4040
]
4141

4242

43-
@pytest.mark.parametrize('optimizer_name', OPTIMIZER_NAMES + ['nero'])
def test_learning_rate(optimizer_name):
    """A negative learning rate must be rejected with ValueError."""
    opt_class = load_optimizer(optimizer_name)

    kwargs = {'lr': -1e-2}
    if optimizer_name == 'ranger21':
        # ranger21 additionally requires num_iterations at construction.
        kwargs['num_iterations'] = 100

    with pytest.raises(ValueError):
        opt_class(None, **kwargs)
52+
4853

54+
@pytest.mark.parametrize('optimizer_name', OPTIMIZER_NAMES)
def test_epsilon(optimizer_name):
    """A negative epsilon must be rejected with ValueError."""
    opt_class = load_optimizer(optimizer_name)

    kwargs = {'eps': -1e-6}
    if optimizer_name == 'ranger21':
        # ranger21 additionally requires num_iterations at construction.
        kwargs['num_iterations'] = 100

    with pytest.raises(ValueError):
        opt_class(None, **kwargs)
63+
5564

65+
@pytest.mark.parametrize('optimizer_name', OPTIMIZER_NAMES)
def test_weight_decay(optimizer_name):
    """A negative weight decay must be rejected with ValueError."""
    opt_class = load_optimizer(optimizer_name)

    kwargs = {'weight_decay': -1e-3}
    if optimizer_name == 'ranger21':
        # ranger21 additionally requires num_iterations at construction.
        kwargs['num_iterations'] = 100

    with pytest.raises(ValueError):
        opt_class(None, **kwargs)
6274

6375

64-
@pytest.mark.parametrize('optimizer_name', ['adamp', 'sgdp'])
def test_wd_ratio(optimizer_name):
    """A negative weight-decay ratio must be rejected with ValueError."""
    opt_class = load_optimizer(optimizer_name)
    with pytest.raises(ValueError):
        opt_class(None, wd_ratio=-1e-3)
6981

7082

71-
@pytest.mark.parametrize('optimizer_name', ['lars'])
def test_trust_coefficient(optimizer_name):
    """A negative trust coefficient must be rejected with ValueError."""
    opt_class = load_optimizer(optimizer_name)
    with pytest.raises(ValueError):
        opt_class(None, trust_coefficient=-1e-3)
7688

7789

78-
@pytest.mark.parametrize('optimizer_name', ['madgrad', 'lars'])
def test_momentum(optimizer_name):
    """A negative momentum must be rejected with ValueError."""
    opt_class = load_optimizer(optimizer_name)
    with pytest.raises(ValueError):
        opt_class(None, momentum=-1e-3)
8395

8496

85-
@pytest.mark.parametrize('optimizer_name', ['ranger'])
def test_lookahead_k(optimizer_name):
    """A negative lookahead step count k must be rejected with ValueError."""
    opt_class = load_optimizer(optimizer_name)
    with pytest.raises(ValueError):
        opt_class(None, k=-1)
90102

91103

92-
@pytest.mark.parametrize('optimizer_name', ['ranger21'])
def test_beta0(optimizer_name):
    """A negative beta0 must be rejected with ValueError."""
    opt_class = load_optimizer(optimizer_name)

    with pytest.raises(ValueError):
        opt_class(None, num_iterations=200, beta0=-0.1)
98109

99110

100111
@pytest.mark.parametrize('optimizer_name', ['nero'])
def test_beta(optimizer_name):
    """A negative beta must be rejected with ValueError.

    Fix: the parametrize argname was left as 'optimizer_names' while the
    function parameter was renamed to 'optimizer_name'; pytest fails
    collection on that mismatch ("uses no argument 'optimizer_names'").
    The argname now matches the parameter, consistent with every other
    test in this file.
    """
    optimizer = load_optimizer(optimizer_name)

    with pytest.raises(ValueError):
        optimizer(None, beta=-0.1)
106116

107117

108-
@pytest.mark.parametrize('optimizer_name', BETA_OPTIMIZER_NAMES)
def test_betas(optimizer_name):
    """Each beta coefficient must be non-negative; ValueError otherwise."""
    opt_class = load_optimizer(optimizer_name)

    invalid_betas = [(-0.1, 0.1), (0.1, -0.1)]
    if optimizer_name == 'adapnm':
        # adapnm takes a third beta; it must be validated as well.
        invalid_betas.append((0.1, 0.1, -0.1))

    for betas in invalid_betas:
        with pytest.raises(ValueError):
            opt_class(None, betas=betas)
121131

122132

123-
@pytest.mark.parametrize('optimizer_names', ['pcgrad'])
124-
def test_reduction(optimizer_names):
133+
def test_reduction():
125134
model: nn.Module = Example()
126135
parameters = model.parameters()
127136
optimizer = load_optimizer('adamp')(parameters)
@@ -130,10 +139,11 @@ def test_reduction(optimizer_names):
130139
PCGrad(optimizer, reduction='wrong')
131140

132141

133-
@pytest.mark.parametrize('optimizer_name', ['shampoo'])
def test_update_frequency(optimizer_name):
    """A non-positive update frequency must be rejected with ValueError."""
    opt_class = load_optimizer(optimizer_name)
    with pytest.raises(ValueError):
        opt_class(None, update_freq=0)
137147

138148

139149
def test_sam_parameters():

0 commit comments

Comments (0)