Commit fb57fe5

update orth optimizer tests
Signed-off-by: Hao Wu <skyw@nvidia.com>
1 parent 9008e43

File tree

pyproject.toml
tests/test_orthogonalized_optimizer.py

2 files changed (+34 −9 lines)

pyproject.toml (5 additions, 2 deletions)

```diff
@@ -178,9 +178,12 @@ source = ["emerging_optimizers/", "/workspace/emerging_optimizers"]
 [tool.coverage.report]
 exclude_lines = [
     "raise ValueError",
-    "except ImportError"
+    "except ImportError",
 ]
 exclude_also = [
-    "@triton"
+    "@triton",
+    ".*sm_version",
+    "if closure",
+    "loss = closure"
 ]
```
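
The new `exclude_also` patterns tell coverage.py to skip lines matching those regexes: SM-version guards that only run on specific GPUs, and the standard PyTorch closure boilerplate in `step()`. A minimal sketch of the kind of lines the closure patterns match (the `SketchOptimizer` class is hypothetical, not from this repo):

```python
import torch


class SketchOptimizer(torch.optim.Optimizer):
    """Hypothetical optimizer, shown only to illustrate the coverage excludes."""

    def __init__(self, params, lr=1e-3):
        super().__init__(params, {"lr": lr})

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:  # excluded by the "if closure" pattern
            with torch.enable_grad():
                loss = closure()  # excluded by the "loss = closure" pattern
        # ... the actual parameter update would go here ...
        return loss
```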

tests/test_orthogonalized_optimizer.py (29 additions, 7 deletions)
```diff
@@ -22,6 +22,29 @@
 
 
 class OrthogonalizedOptimizerTest(parameterized.TestCase):
+    @parameterized.product(
+        use_independent_wd=[True, False],
+        use_decoupled_wd=[True, False],
+        shape=[(5, 7), (33, 65), (127, 257)],
+        use_nesterov=[True, False],
+        fp32_matmul_prec=["highest", "medium", "low"],
+    )
+    def test_smoke(self, use_independent_wd, use_decoupled_wd, shape, use_nesterov, fp32_matmul_prec) -> None:
+        test_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device="cuda"))
+        test_param.grad = torch.randint_like(test_param, -5, 5)
+
+        orthogonalized_opt = OrthogonalizedOptimizer(
+            [test_param],
+            lr=2,
+            momentum_beta=0,
+            weight_decay=0.5,
+            use_nesterov=use_nesterov,
+            use_decoupled_wd=use_decoupled_wd,
+            use_independent_wd=use_independent_wd,
+            fp32_matmul_prec=fp32_matmul_prec,
+        )
+        orthogonalized_opt.step()
+
     @parameterized.parameters(
         {"shape": (5, 7)},
         {"shape": (33, 65)},
```
```diff
@@ -195,23 +218,22 @@ def test_use_independent_wd(self) -> None:
 
         # Test with independent weight decay: with lr=0, weight decay should still be applied
         # With lr=0, no gradient update occurs, so param should be exactly (1-wd)*param
-        indep_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device="cuda"))
-        indep_param_initial = indep_param.data.clone()
-        indep_param.grad = torch.randint_like(indep_param, -5, 5)
+        test_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device="cuda"))
+        test_param.grad = torch.randint_like(test_param, -5, 5)
+        # With independent weight decay and lr=0, param should be exactly (1-wd)*param
+        expected_param = (1 - weight_decay) * test_param.data
 
         muon_opt_indep = muon.Muon(
-            [indep_param],
+            [test_param],
             lr=0.0,  # Zero learning rate
             weight_decay=weight_decay,
             use_independent_wd=True,
             momentum_beta=0.0,
         )
         muon_opt_indep.step()
 
-        # With independent weight decay and lr=0, param should be exactly (1-wd)*param
-        expected_param = (1 - weight_decay) * indep_param_initial
         torch.testing.assert_close(
-            indep_param.data,
+            test_param,
             expected_param,
             atol=0,
             rtol=0,
```
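
The assertion rests on a simple invariant: with `lr=0` the (orthogonalized) gradient term contributes nothing, while an independent weight decay is applied without learning-rate scaling, so a single step reduces to `param ← (1 - wd) * param`. A standalone sketch of that arithmetic, independent of the Muon implementation:

```python
import torch

weight_decay = 0.5
lr = 0.0
param = torch.randint(-5, 5, (5, 7), dtype=torch.float32)
update = torch.randn_like(param)  # stands in for the orthogonalized gradient

# Independent weight decay shrinks the weights regardless of lr;
# with lr = 0 the update term vanishes entirely.
stepped = (1 - weight_decay) * param - lr * update

torch.testing.assert_close(stepped, 0.5 * param, atol=0, rtol=0)
```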
