|
4 | 4 |
|
5 | 5 | from pytorch_optimizer import load_optimizers |
6 | 6 |
|
7 | | -VALID_OPTIMIZER_NAMES: List[str] = [ |
| 7 | +OPTIMIZER_NAMES: List[str] = [ |
8 | 8 | 'adamp', |
9 | 9 | 'sgdp', |
10 | 10 | 'madgrad', |
|
19 | 19 | 'lamb', |
20 | 20 | ] |
21 | 21 |
|
# Optimizers that accept a `betas=(beta1, beta2)` pair and validate it.
# NOTE: fixed 'admap' -> 'adamp'; AdamP is the registered name (see
# OPTIMIZER_NAMES above), so 'admap' would make load_optimizers fail
# for every test parametrized over this list.
BETA_OPTIMIZER_NAMES: List[str] = [
    'adabelief',
    'adabound',
    'adahessian',
    'adamp',
    'diffgrad',
    'diffrgrad',
    'lamb',
    'radam',
    'ranger',
    'ranger21',
]
| 34 | + |
22 | 35 |
|
@pytest.mark.parametrize('optimizer_names', OPTIMIZER_NAMES)
def test_learning_rate(optimizer_names):
    """Every optimizer must reject a negative learning rate with ValueError."""
    with pytest.raises(ValueError):
        # Chain lookup and construction: a negative lr must blow up in __init__.
        load_optimizers(optimizer_names)(None, lr=-1e-2)
28 | 41 |
|
29 | 42 |
|
@pytest.mark.parametrize('optimizer_names', OPTIMIZER_NAMES)
def test_epsilon(optimizer_names):
    """Every optimizer must reject a negative epsilon with ValueError."""
    with pytest.raises(ValueError):
        # Chain lookup and construction: a negative eps must blow up in __init__.
        load_optimizers(optimizer_names)(None, eps=-1e-6)
35 | 48 |
|
36 | 49 |
|
@pytest.mark.parametrize('optimizer_names', OPTIMIZER_NAMES)
def test_weight_decay(optimizer_names):
    """Every optimizer must reject a negative weight decay with ValueError."""
    with pytest.raises(ValueError):
        # Chain lookup and construction: a negative weight_decay must blow up.
        load_optimizers(optimizer_names)(None, weight_decay=-1e-3)
| 55 | + |
| 56 | + |
@pytest.mark.parametrize('optimizer_names', BETA_OPTIMIZER_NAMES)
def test_betas(optimizer_names):
    """A negative value in either beta coefficient must raise ValueError."""
    # Check beta1 < 0 and beta2 < 0 separately; each must fail on its own.
    for invalid_betas in ((-0.1, 0.1), (0.1, -0.1)):
        with pytest.raises(ValueError):
            load_optimizers(optimizer_names)(None, betas=invalid_betas)
0 commit comments