|
40 | 40 | ] |
41 | 41 |
|
42 | 42 |
|
43 | | -@pytest.mark.parametrize('optimizer_names', OPTIMIZER_NAMES + ['nero']) |
44 | | -def test_learning_rate(optimizer_names): |
| 43 | +@pytest.mark.parametrize('optimizer_name', OPTIMIZER_NAMES + ['nero']) |
| 44 | +def test_learning_rate(optimizer_name): |
| 45 | + optimizer = load_optimizer(optimizer_name) |
| 46 | + |
45 | 47 | with pytest.raises(ValueError): |
46 | | - optimizer = load_optimizer(optimizer_names) |
47 | | - optimizer(None, lr=-1e-2) |
| 48 | + if optimizer_name == 'ranger21': |
| 49 | + optimizer(None, num_iterations=100, lr=-1e-2) |
| 50 | + else: |
| 51 | + optimizer(None, lr=-1e-2) |
| 52 | + |
48 | 53 |
|
| 54 | +@pytest.mark.parametrize('optimizer_name', OPTIMIZER_NAMES) |
| 55 | +def test_epsilon(optimizer_name): |
| 56 | + optimizer = load_optimizer(optimizer_name) |
49 | 57 |
|
50 | | -@pytest.mark.parametrize('optimizer_names', OPTIMIZER_NAMES) |
51 | | -def test_epsilon(optimizer_names): |
52 | 58 | with pytest.raises(ValueError): |
53 | | - optimizer = load_optimizer(optimizer_names) |
54 | | - optimizer(None, eps=-1e-6) |
| 59 | + if optimizer_name == 'ranger21': |
| 60 | + optimizer(None, num_iterations=100, eps=-1e-6) |
| 61 | + else: |
| 62 | + optimizer(None, eps=-1e-6) |
55 | 63 |
|
56 | 64 |
|
57 | | -@pytest.mark.parametrize('optimizer_names', OPTIMIZER_NAMES) |
58 | | -def test_weight_decay(optimizer_names): |
| 65 | +@pytest.mark.parametrize('optimizer_name', OPTIMIZER_NAMES) |
| 66 | +def test_weight_decay(optimizer_name): |
| 67 | + optimizer = load_optimizer(optimizer_name) |
| 68 | + |
59 | 69 | with pytest.raises(ValueError): |
60 | | - optimizer = load_optimizer(optimizer_names) |
61 | | - optimizer(None, weight_decay=-1e-3) |
| 70 | + if optimizer_name == 'ranger21': |
| 71 | + optimizer(None, num_iterations=100, weight_decay=-1e-3) |
| 72 | + else: |
| 73 | + optimizer(None, weight_decay=-1e-3) |
62 | 74 |
|
63 | 75 |
|
64 | | -@pytest.mark.parametrize('optimizer_names', ['adamp', 'sgdp']) |
65 | | -def test_wd_ratio(optimizer_names): |
| 76 | +@pytest.mark.parametrize('optimizer_name', ['adamp', 'sgdp']) |
| 77 | +def test_wd_ratio(optimizer_name): |
| 78 | + optimizer = load_optimizer(optimizer_name) |
66 | 79 | with pytest.raises(ValueError): |
67 | | - optimizer = load_optimizer(optimizer_names) |
68 | 80 | optimizer(None, wd_ratio=-1e-3) |
69 | 81 |
|
70 | 82 |
|
71 | | -@pytest.mark.parametrize('optimizer_names', ['lars']) |
72 | | -def test_trust_coefficient(optimizer_names): |
| 83 | +@pytest.mark.parametrize('optimizer_name', ['lars']) |
| 84 | +def test_trust_coefficient(optimizer_name): |
| 85 | + optimizer = load_optimizer(optimizer_name) |
73 | 86 | with pytest.raises(ValueError): |
74 | | - optimizer = load_optimizer(optimizer_names) |
75 | 87 | optimizer(None, trust_coefficient=-1e-3) |
76 | 88 |
|
77 | 89 |
|
78 | | -@pytest.mark.parametrize('optimizer_names', ['madgrad', 'lars']) |
79 | | -def test_momentum(optimizer_names): |
| 90 | +@pytest.mark.parametrize('optimizer_name', ['madgrad', 'lars']) |
| 91 | +def test_momentum(optimizer_name): |
| 92 | + optimizer = load_optimizer(optimizer_name) |
80 | 93 | with pytest.raises(ValueError): |
81 | | - optimizer = load_optimizer(optimizer_names) |
82 | 94 | optimizer(None, momentum=-1e-3) |
83 | 95 |
|
84 | 96 |
|
85 | | -@pytest.mark.parametrize('optimizer_names', ['ranger']) |
86 | | -def test_lookahead_k(optimizer_names): |
| 97 | +@pytest.mark.parametrize('optimizer_name', ['ranger']) |
| 98 | +def test_lookahead_k(optimizer_name): |
| 99 | + optimizer = load_optimizer(optimizer_name) |
87 | 100 | with pytest.raises(ValueError): |
88 | | - optimizer = load_optimizer(optimizer_names) |
89 | 101 | optimizer(None, k=-1) |
90 | 102 |
|
91 | 103 |
|
92 | | -@pytest.mark.parametrize('optimizer_names', ['ranger21']) |
93 | | -def test_beta0(optimizer_names): |
94 | | - optimizer = load_optimizer(optimizer_names) |
95 | | - |
| 104 | +@pytest.mark.parametrize('optimizer_name', ['ranger21']) |
| 105 | +def test_beta0(optimizer_name): |
| 106 | + optimizer = load_optimizer(optimizer_name) |
96 | 107 | with pytest.raises(ValueError): |
97 | 108 | optimizer(None, num_iterations=200, beta0=-0.1) |
98 | 109 |
|
99 | 110 |
|
100 | | -@pytest.mark.parametrize('optimizer_names', ['nero']) |
101 | | -def test_beta(optimizer_names): |
102 | | - optimizer = load_optimizer(optimizer_names) |
103 | | - |
| 111 | +@pytest.mark.parametrize('optimizer_name', ['nero']) |
| 112 | +def test_beta(optimizer_name): |
| 113 | + optimizer = load_optimizer(optimizer_name) |
104 | 114 | with pytest.raises(ValueError): |
105 | 115 | optimizer(None, beta=-0.1) |
106 | 116 |
|
107 | 117 |
|
108 | | -@pytest.mark.parametrize('optimizer_names', BETA_OPTIMIZER_NAMES) |
109 | | -def test_betas(optimizer_names): |
110 | | - optimizer = load_optimizer(optimizer_names) |
| 118 | +@pytest.mark.parametrize('optimizer_name', BETA_OPTIMIZER_NAMES) |
| 119 | +def test_betas(optimizer_name): |
| 120 | + optimizer = load_optimizer(optimizer_name) |
111 | 121 |
|
112 | 122 | with pytest.raises(ValueError): |
113 | | - optimizer(None, betas=(-0.1, 0.1)) |
| 123 | + if optimizer_name == 'ranger21': |
| 124 | + optimizer(None, num_iterations=100, betas=(-0.1, 0.1)) |
| 125 | + else: |
| 126 | + optimizer(None, betas=(-0.1, 0.1)) |
114 | 127 |
|
115 | 128 | with pytest.raises(ValueError): |
116 | | - optimizer(None, betas=(0.1, -0.1)) |
| 129 | + if optimizer_name == 'ranger21': |
| 130 | + optimizer(None, num_iterations=100, betas=(0.1, -0.1)) |
| 131 | + else: |
| 132 | + optimizer(None, betas=(0.1, -0.1)) |
117 | 133 |
|
118 | | - if optimizer_names == 'adapnm': |
| 134 | + if optimizer_name == 'adapnm': |
119 | 135 | with pytest.raises(ValueError): |
120 | 136 | optimizer(None, betas=(0.1, 0.1, -0.1)) |
121 | 137 |
|
122 | 138 |
|
123 | | -@pytest.mark.parametrize('optimizer_names', ['pcgrad']) |
124 | | -def test_reduction(optimizer_names): |
125 | | - model: nn.Module = Example() |
126 | | - parameters = model.parameters() |
| 139 | +def test_reduction(): |
| 140 | + parameters = Example().parameters() |
127 | 141 | optimizer = load_optimizer('adamp')(parameters) |
128 | 142 |
|
129 | 143 | with pytest.raises(ValueError): |
130 | 144 | PCGrad(optimizer, reduction='wrong') |
131 | 145 |
|
132 | 146 |
|
133 | | -@pytest.mark.parametrize('optimizer_names', ['shampoo']) |
134 | | -def test_update_frequency(optimizer_names): |
| 147 | +@pytest.mark.parametrize('optimizer_name', ['shampoo']) |
| 148 | +def test_update_frequency(optimizer_name): |
| 149 | + optimizer = load_optimizer(optimizer_name) |
135 | 150 | with pytest.raises(ValueError): |
136 | | - load_optimizer(optimizer_names)(None, update_freq=0) |
| 151 | + optimizer(None, update_freq=0) |
137 | 152 |
|
138 | 153 |
|
139 | 154 | def test_sam_parameters(): |
|
0 commit comments