Skip to content

Commit d017c6c

Browse files
committed
update: test_sam_family_methods
1 parent b4f9468 commit d017c6c

File tree

2 files changed

+11
-32
lines changed

2 files changed

+11
-32
lines changed

tests/test_optimizer_parameters.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@
77
WSAM,
88
Lookahead,
99
LookSAM,
10+
BSAM,
1011
PCGrad,
1112
Ranger21,
1213
SafeFP16Optimizer,
1314
load_optimizer,
1415
)
16+
from pytorch_optimizer.base.exception import NoClosureError
1517
from pytorch_optimizer.optimizer.galore_utils import GaLoreProjector
1618
from tests.constants import PULLBACK_MOMENTUM
1719
from tests.utils import Example, simple_parameter
@@ -86,16 +88,21 @@ def test_lookahead_parameters():
8688
Lookahead(optimizer, pullback_momentum='invalid')
8789

8890

89-
@pytest.mark.parametrize('optimizer', [SAM, WSAM, LookSAM])
91+
@pytest.mark.parametrize('optimizer', [SAM, WSAM, LookSAM, BSAM])
9092
def test_sam_family_methods(optimizer):
9193
base_optimizer = load_optimizer('lion')
9294

93-
opt = optimizer(params=[simple_parameter()], model=None, base_optimizer=base_optimizer)
94-
opt.init_group({})
95+
opt = optimizer(params=[simple_parameter()], model=None, base_optimizer=base_optimizer, num_data=1)
96+
opt.zero_grad()
97+
98+
opt.init_group({'params': []})
9599
opt.load_state_dict(opt.state_dict())
96100

101+
with pytest.raises(NoClosureError):
102+
opt.step()
103+
97104
with pytest.raises(ValueError):
98-
optimizer(model=None, params=None, base_optimizer=base_optimizer, rho=-0.1)
105+
optimizer(model=None, params=None, base_optimizer=base_optimizer, rho=-0.1, num_data=1)
99106

100107

101108
def test_safe_fp16_methods():

tests/test_optimizers.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -210,34 +210,6 @@ def test_closure(optimizer):
210210
optimizer.step(closure=dummy_closure)
211211

212212

213-
def test_no_closure():
214-
param = simple_parameter()
215-
216-
optimizer = SAM([param], load_optimizer('adamp'))
217-
optimizer.zero_grad()
218-
219-
with pytest.raises(NoClosureError):
220-
optimizer.step()
221-
222-
optimizer = WSAM(None, [param], load_optimizer('adamp'))
223-
optimizer.zero_grad()
224-
225-
with pytest.raises(NoClosureError):
226-
optimizer.step()
227-
228-
optimizer = BSAM([param], 1)
229-
optimizer.zero_grad()
230-
231-
with pytest.raises(NoClosureError):
232-
optimizer.step()
233-
234-
optimizer = LookSAM([param], load_optimizer('adamp'))
235-
optimizer.zero_grad()
236-
237-
with pytest.raises(NoClosureError):
238-
optimizer.step()
239-
240-
241213
def test_nero_zero_scale():
242214
param = simple_parameter()
243215

0 commit comments

Comments
 (0)