|
22 | 22 |
|
23 | 23 |
|
24 | 24 | class OrthogonalizedOptimizerTest(parameterized.TestCase): |
| 25 | + @parameterized.product( |
| 26 | + use_independent_wd=[True, False], |
| 27 | + use_decoupled_wd=[True, False], |
| 28 | + shape=[(5, 7), (33, 65), (127, 257)], |
| 29 | + use_nesterov=[True, False], |
| 30 | + fp32_matmul_prec=["highest", "medium", "low"], |
| 31 | + ) |
| 32 | + def test_smoke(self, use_independent_wd, use_decoupled_wd, shape, use_nesterov, fp32_matmul_prec) -> None: |
| 33 | + test_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device="cuda")) |
| 34 | + test_param.grad = torch.randint_like(test_param, -5, 5) |
| 35 | + |
| 36 | + orthogonalized_opt = OrthogonalizedOptimizer( |
| 37 | + [test_param], |
| 38 | + lr=2, |
| 39 | + momentum_beta=0, |
| 40 | + weight_decay=0.5, |
| 41 | + use_nesterov=use_nesterov, |
| 42 | + use_decoupled_wd=use_decoupled_wd, |
| 43 | + use_independent_wd=use_independent_wd, |
| 44 | + fp32_matmul_prec=fp32_matmul_prec, |
| 45 | + ) |
| 46 | + orthogonalized_opt.step() |
| 47 | + |
25 | 48 | @parameterized.parameters( |
26 | 49 | {"shape": (5, 7)}, |
27 | 50 | {"shape": (33, 65)}, |
@@ -195,23 +218,22 @@ def test_use_independent_wd(self) -> None: |
195 | 218 |
|
196 | 219 | # Test with independent weight decay: with lr=0, weight decay should still be applied |
197 | 220 | # With lr=0, no gradient update occurs, so param should be exactly (1-wd)*param |
198 | | - indep_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device="cuda")) |
199 | | - indep_param_initial = indep_param.data.clone() |
200 | | - indep_param.grad = torch.randint_like(indep_param, -5, 5) |
| 221 | + test_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device="cuda")) |
| 222 | + test_param.grad = torch.randint_like(test_param, -5, 5) |
| 223 | + # With independent weight decay and lr=0, param should be exactly (1-wd)*param |
| 224 | + expected_param = (1 - weight_decay) * test_param.data |
201 | 225 |
|
202 | 226 | muon_opt_indep = muon.Muon( |
203 | | - [indep_param], |
| 227 | + [test_param], |
204 | 228 | lr=0.0, # Zero learning rate |
205 | 229 | weight_decay=weight_decay, |
206 | 230 | use_independent_wd=True, |
207 | 231 | momentum_beta=0.0, |
208 | 232 | ) |
209 | 233 | muon_opt_indep.step() |
210 | 234 |
|
211 | | - # With independent weight decay and lr=0, param should be exactly (1-wd)*param |
212 | | - expected_param = (1 - weight_decay) * indep_param_initial |
213 | 235 | torch.testing.assert_close( |
214 | | - indep_param.data, |
| 236 | + test_param, |
215 | 237 | expected_param, |
216 | 238 | atol=0, |
217 | 239 | rtol=0, |
|
0 commit comments