Skip to content

Commit b4f9468

Browse files
committed
update: test cases
1 parent 057c28d commit b4f9468

File tree

2 files changed

+9
-24
lines changed

2 files changed

+9
-24
lines changed

tests/test_optimizer_parameters.py

Lines changed: 8 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -62,16 +62,6 @@ def test_pcgrad_parameters():
6262
PCGrad(opt, reduction='invalid')
6363

6464

65-
def test_sam_parameters():
66-
with pytest.raises(ValueError):
67-
SAM(None, load_optimizer('adamp'), rho=-0.1)
68-
69-
70-
def test_wsam_parameters():
71-
with pytest.raises(ValueError):
72-
WSAM(None, None, load_optimizer('adamp'), rho=-0.1)
73-
74-
7565
def test_lookahead_parameters():
7666
optimizer_instance = load_optimizer('adamp')
7767
optimizer = optimizer_instance([simple_parameter()])
@@ -96,22 +86,16 @@ def test_lookahead_parameters():
9686
Lookahead(optimizer, pullback_momentum='invalid')
9787

9888

99-
def test_sam_methods():
100-
optimizer = SAM([simple_parameter()], load_optimizer('adamp'))
101-
optimizer.init_group()
102-
optimizer.load_state_dict(optimizer.state_dict())
103-
104-
105-
def test_wsam_methods():
106-
optimizer = WSAM(None, [simple_parameter()], load_optimizer('adamp'))
107-
optimizer.init_group()
108-
optimizer.load_state_dict(optimizer.state_dict())
89+
@pytest.mark.parametrize('optimizer', [SAM, WSAM, LookSAM])
90+
def test_sam_family_methods(optimizer):
91+
base_optimizer = load_optimizer('lion')
10992

93+
opt = optimizer(params=[simple_parameter()], model=None, base_optimizer=base_optimizer)
94+
opt.init_group({})
95+
opt.load_state_dict(opt.state_dict())
11096

111-
def test_looksam_methods():
112-
optimizer = LookSAM([simple_parameter()], load_optimizer('adamp'))
113-
optimizer.init_group()
114-
optimizer.load_state_dict(optimizer.state_dict())
97+
with pytest.raises(ValueError):
98+
optimizer(model=None, params=None, base_optimizer=base_optimizer, rho=-0.1)
11599

116100

117101
def test_safe_fp16_methods():

tests/test_wrapper_optimizers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ def test_looksam_optimizer(environment):
105105
for _ in range(5):
106106
loss = loss_fn(y_data, model(x_data))
107107
loss.backward()
108+
108109
optimizer.first_step(zero_grad=True)
109110

110111
loss_fn(y_data, model(x_data)).backward()

0 commit comments

Comments
 (0)