Skip to content

Commit 476df4e

Browse files
committed
refactor: test_apollo_parameters
1 parent d32db05 commit 476df4e

File tree

2 files changed

+12
-16
lines changed

2 files changed

+12
-16
lines changed

tests/test_general_optimizer_parameters.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -70,22 +70,6 @@ def test_weight_decay(optimizer_name):
7070
optimizer(None, **config)
7171

7272

73-
@pytest.mark.parametrize('optimizer_name', ['apollo'])
74-
def test_weight_decay_type(optimizer_name):
75-
optimizer = load_optimizer(optimizer_name)
76-
77-
with pytest.raises(ValueError):
78-
optimizer(None, weight_decay_type='dummy')
79-
80-
81-
@pytest.mark.parametrize('optimizer_name', ['apollo'])
82-
def test_rebound(optimizer_name):
83-
optimizer = load_optimizer(optimizer_name)
84-
85-
with pytest.raises(ValueError):
86-
optimizer(None, rebound='dummy')
87-
88-
8973
@pytest.mark.parametrize('optimizer_name', ['adamp', 'sgdp'])
9074
def test_wd_ratio(optimizer_name):
9175
optimizer = load_optimizer(optimizer_name)

tests/test_optimizer_parameters.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,3 +210,15 @@ def test_lars_parameters():
210210
# test trust_coefficient
211211
with pytest.raises(ValueError):
212212
opt(None, trust_coefficient=-1e-3)
213+
214+
215+
def test_apollo_parameters():
216+
opt = load_optimizer('apollo')
217+
218+
# test rebound type
219+
with pytest.raises(ValueError):
220+
opt(None, rebound='dummy')
221+
222+
# test weight_decay_type
223+
with pytest.raises(ValueError):
224+
opt(None, weight_decay_type='dummy')

0 commit comments

Comments
 (0)