
Commit 19d8201

Independent wd for orthogonalized optimizer (#66)
* added independent wd

Signed-off-by: mikail <[email protected]>
1 parent bc13a6b commit 19d8201

File tree

3 files changed: +54 -11 lines

emerging_optimizers/orthogonalized_optimizers/muon.py
emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py
tests/test_orthogonalized_optimizer.py

emerging_optimizers/orthogonalized_optimizers/muon.py

Lines changed: 4 additions & 2 deletions

@@ -66,7 +66,8 @@ def __init__(
         momentum_beta: float = 0.95,
         use_nesterov: bool = False,
         weight_decay: float = 0.01,
-        use_decoupled_weight_decay: bool = True,
+        use_decoupled_wd: bool = True,
+        use_independent_wd: bool = False,
         fp32_matmul_prec: str = "medium",
         coefficient_type: str = "quintic",
         num_ns_steps: int = 5,

@@ -106,7 +107,8 @@ def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor:
             momentum_beta,
             use_nesterov,
             weight_decay,
-            use_decoupled_weight_decay,
+            use_decoupled_wd,
+            use_independent_wd,
             fp32_matmul_prec,
             scaled_orthogonalize_fn,
         )
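
Taken together, the Muon constructor now exposes the renamed use_decoupled_wd flag and the new use_independent_wd flag. Below is a minimal usage sketch, not part of the commit: the import path is assumed from the file layout above, the lr and weight_decay values are hypothetical, and only the keyword arguments shown in this diff and in the test further down are taken from the source.

import torch
from torch import nn

from emerging_optimizers.orthogonalized_optimizers import muon  # assumed import path

# Hypothetical 2D weight; orthogonalized optimizers operate on matrix-shaped parameters.
weight = nn.Parameter(torch.randn(64, 64))

opt = muon.Muon(
    [weight],
    lr=0.02,                  # hypothetical value
    weight_decay=0.1,         # hypothetical value
    use_decoupled_wd=True,    # decay the weights directly, leave the gradients untouched
    use_independent_wd=True,  # scale the decay by wd alone rather than lr * wd
)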

emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py

Lines changed: 15 additions & 6 deletions

@@ -35,7 +35,9 @@
         use_nesterov: Whether to use Nesterov-style momentum in the internal SGD.
         weight_decay: The weight decay used by the optimizer, default to be decoupled weight decay.
             See Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101
-        use_decoupled_weight_decay: Whether to use decoupled weight decay, default to be True.
+        use_decoupled_wd: Whether to use decoupled weight decay, default to be True.
+        use_independent_wd: Whether to use independent weight decay (https://arxiv.org/abs/2510.19093),
+            default to be False.
         fp32_matmul_prec: Precision of the matmul operations in optimizer states GEMM operations.
     """

@@ -99,7 +101,8 @@ def __init__(
         momentum_beta: float,
         use_nesterov: bool,
         weight_decay: float,
-        use_decoupled_weight_decay: bool,
+        use_decoupled_wd: bool,
+        use_independent_wd: bool,
         fp32_matmul_prec: str,
         scaled_orthogonalize_fn: Callable | None = None,
         **kwargs: Any,

@@ -114,7 +117,8 @@ def __init__(
             momentum_beta=momentum_beta,
             use_nesterov=use_nesterov,
             weight_decay=weight_decay,
-            use_decoupled_weight_decay=use_decoupled_weight_decay,
+            use_decoupled_wd=use_decoupled_wd,
+            use_independent_wd=use_independent_wd,
             **kwargs,
         )

@@ -152,9 +156,14 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:

             # Apply weight decay
             if group["weight_decay"] > 0.0:
-                if group["use_decoupled_weight_decay"]:
-                    # Apply decoupled weight decay
-                    p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
+                if group["use_decoupled_wd"]:
+                    # Apply weight decay directly to params without changing gradients
+                    if group["use_independent_wd"]:
+                        # do not tie weight decay and learning rate
+                        weight_decay_scale = group["weight_decay"]
+                    else:
+                        weight_decay_scale = group["weight_decay"] * group["lr"]
+                    p.add_(p, alpha=(-weight_decay_scale))
                 else:
                     # add l2 regularization before preconditioning (i.e. adding a squared loss term)
                     grad += group["weight_decay"] * p
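
For readers skimming the hunk above: the commit distinguishes three weight-decay modes. The sketch below is a hypothetical standalone helper, not code from the repository, that mirrors the branching added to step(): plain L2 regularization folds the decay into the gradient before preconditioning, decoupled weight decay shrinks the parameter by a factor of (1 - lr * wd), and independent weight decay shrinks it by (1 - wd) with no dependence on the learning rate.

import torch

def apply_weight_decay_sketch(
    p: torch.Tensor,
    grad: torch.Tensor,
    lr: float,
    weight_decay: float,
    use_decoupled_wd: bool,
    use_independent_wd: bool,
) -> torch.Tensor:
    """Hypothetical helper mirroring the weight-decay branching in the patched step()."""
    if weight_decay > 0.0:
        if use_decoupled_wd:
            # Decay is applied to the parameter itself; the gradient is left unchanged.
            if use_independent_wd:
                scale = weight_decay        # independent: p <- (1 - wd) * p, regardless of lr
            else:
                scale = weight_decay * lr   # decoupled (AdamW-style): p <- (1 - lr * wd) * p
            p.add_(p, alpha=-scale)
        else:
            # L2 regularization: fold the decay into the gradient before preconditioning.
            grad = grad + weight_decay * p
    return grad

With lr = 0, the independent mode still multiplies the parameter by (1 - wd), which is exactly the property the new unit test below checks.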

tests/test_orthogonalized_optimizer.py

Lines changed: 35 additions & 3 deletions

@@ -42,7 +42,8 @@ def test_orthogonalized_optimizer_core_matches_sgd(self, shape) -> None:
             momentum_beta=0,
             use_nesterov=False,
             weight_decay=0.5,
-            use_decoupled_weight_decay=True,
+            use_decoupled_wd=True,
+            use_independent_wd=False,
             fp32_matmul_prec="highest",
         )

@@ -83,7 +84,8 @@ def test_orthogonalized_optimizer_core_matches_sgd_with_momentum(self, shape) -> None:
             momentum_beta=0.5,
             use_nesterov=False,
             weight_decay=0.0,
-            use_decoupled_weight_decay=False,
+            use_decoupled_wd=False,
+            use_independent_wd=False,
             fp32_matmul_prec="highest",
         )

@@ -133,7 +135,8 @@ def dummy_interleaved_split_orth_fn(x: torch.Tensor) -> torch.Tensor:
             momentum_beta=0,
             use_nesterov=False,
             weight_decay=0.0,
-            use_decoupled_weight_decay=False,
+            use_decoupled_wd=False,
+            use_independent_wd=False,
             fp32_matmul_prec="highest",
             scaled_orthogonalize_fn=dummy_interleaved_split_orth_fn,
         )

@@ -185,6 +188,35 @@ def test_use_syrk_match_without_syrk(self) -> None:
             ref_param.data,
         )

+    def test_use_independent_wd(self) -> None:
+        """Test that use_independent_wd properly decouples weight decay from learning rate."""
+        shape = (32, 32)
+        weight_decay = 0.25
+
+        # Test with independent weight decay: with lr=0, weight decay should still be applied
+        # With lr=0, no gradient update occurs, so param should be exactly (1-wd)*param
+        indep_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device="cuda"))
+        indep_param_initial = indep_param.data.clone()
+        indep_param.grad = torch.randint_like(indep_param, -5, 5)
+
+        muon_opt_indep = muon.Muon(
+            [indep_param],
+            lr=0.0,  # Zero learning rate
+            weight_decay=weight_decay,
+            use_independent_wd=True,
+            momentum_beta=0.0,
+        )
+        muon_opt_indep.step()
+
+        # With independent weight decay and lr=0, param should be exactly (1-wd)*param
+        expected_param = (1 - weight_decay) * indep_param_initial
+        torch.testing.assert_close(
+            indep_param.data,
+            expected_param,
+            atol=0,
+            rtol=0,
+        )
+

 if __name__ == "__main__":
     absltest.main()
