Skip to content

Commit d32db05

Browse files
committed
refactor: test_beta, test_betas
1 parent 5404b44 commit d32db05

File tree

1 file changed

+13
-10
lines changed

1 file changed

+13
-10
lines changed

tests/test_general_optimizer_parameters.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -107,18 +107,18 @@ def test_lookahead_k(optimizer_name):
107107
optimizer(None, k=-1)
108108

109109

110-
@pytest.mark.parametrize('optimizer_name', ['ranger21'])
111-
def test_beta0(optimizer_name):
112-
optimizer = load_optimizer(optimizer_name)
113-
with pytest.raises(ValueError):
114-
optimizer(None, num_iterations=200, beta0=-0.1)
115-
116-
117-
@pytest.mark.parametrize('optimizer_name', ['nero', 'apollo', 'sm3', 'msvag'])
110+
@pytest.mark.parametrize('optimizer_name', ['nero', 'apollo', 'sm3', 'msvag', 'ranger21'])
118111
def test_beta(optimizer_name):
119112
optimizer = load_optimizer(optimizer_name)
120-
with pytest.raises(ValueError):
121-
optimizer(None, beta=-0.1)
113+
114+
if optimizer_name == 'ranger21':
115+
# test beta0
116+
with pytest.raises(ValueError):
117+
optimizer(None, num_iterations=200, beta0=-0.1)
118+
else:
119+
# test beta
120+
with pytest.raises(ValueError):
121+
optimizer(None, beta=-0.1)
122122

123123

124124
@pytest.mark.parametrize('optimizer_name', BETA_OPTIMIZER_NAMES)
@@ -137,6 +137,9 @@ def test_betas(optimizer_name):
137137

138138
with pytest.raises(ValueError):
139139
optimizer(None, **config2)
140+
elif optimizer_name == 'prodigy':
141+
with pytest.raises(ValueError):
142+
optimizer(None, beta3=-0.1)
140143
else:
141144
with pytest.raises(ValueError):
142145
optimizer(None, betas=(0.1, 0.1, -0.1))

0 commit comments

Comments
 (0)