     WSAM,
     Lookahead,
     LookSAM,
+    BSAM,
     PCGrad,
     Ranger21,
     SafeFP16Optimizer,
     load_optimizer,
 )
+from pytorch_optimizer.base.exception import NoClosureError
 from pytorch_optimizer.optimizer.galore_utils import GaLoreProjector
 from tests.constants import PULLBACK_MOMENTUM
 from tests.utils import Example, simple_parameter
@@ -86,16 +88,21 @@ def test_lookahead_parameters():
         Lookahead(optimizer, pullback_momentum='invalid')


-@pytest.mark.parametrize('optimizer', [SAM, WSAM, LookSAM])
+@pytest.mark.parametrize('optimizer', [SAM, WSAM, LookSAM, BSAM])
 def test_sam_family_methods(optimizer):
     base_optimizer = load_optimizer('lion')

-    opt = optimizer(params=[simple_parameter()], model=None, base_optimizer=base_optimizer)
-    opt.init_group({})
+    opt = optimizer(params=[simple_parameter()], model=None, base_optimizer=base_optimizer, num_data=1)
+    opt.zero_grad()
+
+    opt.init_group({'params': []})
     opt.load_state_dict(opt.state_dict())

+    with pytest.raises(NoClosureError):
+        opt.step()
+
     with pytest.raises(ValueError):
-        optimizer(model=None, params=None, base_optimizer=base_optimizer, rho=-0.1)
+        optimizer(model=None, params=None, base_optimizer=base_optimizer, rho=-0.1, num_data=1)


 def test_safe_fp16_methods():
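
The new `NoClosureError` assertion reflects that SAM-family optimizers need a closure that re-evaluates the loss so the second (sharpness-aware) step can use fresh gradients at the perturbed weights. A minimal usage sketch of that pattern, mirroring the constructor style used in the test above; the model, criterion, and data here are illustrative, not part of the library:

```python
import torch

from pytorch_optimizer import SAM, load_optimizer

# Illustrative model and data; any torch module and loss would do.
model = torch.nn.Linear(4, 1)
criterion = torch.nn.MSELoss()
x, y = torch.randn(8, 4), torch.randn(8, 1)

# Same construction pattern as the test: a base optimizer class plus the SAM wrapper.
optimizer = SAM(model.parameters(), base_optimizer=load_optimizer('lion'), rho=0.05)


def closure():
    # Re-evaluates the loss at the perturbed weights for the second SAM step.
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    return loss


loss = criterion(model(x), y)
loss.backward()
optimizer.step(closure)  # calling step() without a closure raises NoClosureError
optimizer.zero_grad()
```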