6 changes: 4 additions & 2 deletions emerging_optimizers/orthogonalized_optimizers/muon.py
@@ -66,7 +66,8 @@ def __init__(
momentum_beta: float = 0.95,
use_nesterov: bool = False,
weight_decay: float = 0.01,
- use_decoupled_weight_decay: bool = True,
+ use_decoupled_wd: bool = True,
+ use_independent_wd: bool = False,
fp32_matmul_prec: str = "medium",
coefficient_type: str = "quintic",
num_ns_steps: int = 5,
@@ -106,7 +107,8 @@ def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor:
momentum_beta,
use_nesterov,
weight_decay,
- use_decoupled_weight_decay,
+ use_decoupled_wd,
+ use_independent_wd,
fp32_matmul_prec,
scaled_orthogonalize_fn,
)
@@ -35,7 +35,9 @@
use_nesterov: Whether to use Nesterov-style momentum in the internal SGD.
weight_decay: The weight decay used by the optimizer, default to be decoupled weight decay.
See Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101
- use_decoupled_weight_decay: Whether to use decoupled weight decay, default to be True.
+ use_decoupled_wd: Whether to use decoupled weight decay, defaults to True.
+ use_independent_wd: Whether to use independent weight decay (https://arxiv.org/abs/2510.19093),
+     defaults to False.
fp32_matmul_prec: Precision of the matmul operations in optimizer states GEMM operations.
"""

@@ -99,7 +101,8 @@ def __init__(
momentum_beta: float,
use_nesterov: bool,
weight_decay: float,
- use_decoupled_weight_decay: bool,
+ use_decoupled_wd: bool,
+ use_independent_wd: bool,
fp32_matmul_prec: str,
scaled_orthogonalize_fn: Callable | None = None,
**kwargs: Any,
@@ -114,7 +117,8 @@ def __init__(
momentum_beta=momentum_beta,
use_nesterov=use_nesterov,
weight_decay=weight_decay,
- use_decoupled_weight_decay=use_decoupled_weight_decay,
+ use_decoupled_wd=use_decoupled_wd,
+ use_independent_wd=use_independent_wd,
**kwargs,
)

@@ -152,9 +156,14 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:

# Apply weight decay
if group["weight_decay"] > 0.0:
if group["use_decoupled_weight_decay"]:
# Apply decoupled weight decay
p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
if group["use_decoupled_wd"]:
# Apply weight decay directly to params without changing gradients
if group["use_independent_wd"]:
# do not tie weight decay and learning rate
weight_decay_scale = group["weight_decay"]
else:
weight_decay_scale = group["weight_decay"] * group["lr"]
p.add_(p, alpha=(-weight_decay_scale))
else:
# add l2 regularization before preconditioning (i.e. adding a squared loss term)
grad += group["weight_decay"] * p
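A small self-contained sketch of the decoupled branch added above, showing why "independent" decouples the per-step shrinkage from the learning rate. The function and variable names here are ours, not the library's.

import torch

def apply_wd(p: torch.Tensor, lr: float, wd: float, independent: bool) -> torch.Tensor:
    # Mirrors the decoupled branch: the scale is wd alone when independent, wd * lr otherwise.
    scale = wd if independent else wd * lr
    return p - scale * p

p = torch.ones(2, 2)
for lr in (1e-1, 1e-3):
    shrunk_dec = apply_wd(p, lr, wd=0.1, independent=False)  # shrink factor (1 - lr * wd) depends on lr
    shrunk_ind = apply_wd(p, lr, wd=0.1, independent=True)   # shrink factor (1 - wd) is the same for any lr
    print(lr, shrunk_dec[0, 0].item(), shrunk_ind[0, 0].item())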
38 changes: 35 additions & 3 deletions tests/test_orthogonalized_optimizer.py
@@ -42,7 +42,8 @@ def test_orthogonalized_optimizer_core_matches_sgd(self, shape) -> None:
momentum_beta=0,
use_nesterov=False,
weight_decay=0.5,
- use_decoupled_weight_decay=True,
+ use_decoupled_wd=True,
+ use_independent_wd=False,
fp32_matmul_prec="highest",
)

@@ -83,7 +84,8 @@ def test_orthogonalized_optimizer_core_matches_sgd_with_momentum(self, shape) ->
momentum_beta=0.5,
use_nesterov=False,
weight_decay=0.0,
- use_decoupled_weight_decay=False,
+ use_decoupled_wd=False,
+ use_independent_wd=False,
fp32_matmul_prec="highest",
)

@@ -133,7 +135,8 @@ def dummy_interleaved_split_orth_fn(x: torch.Tensor) -> torch.Tensor:
momentum_beta=0,
use_nesterov=False,
weight_decay=0.0,
- use_decoupled_weight_decay=False,
+ use_decoupled_wd=False,
+ use_independent_wd=False,
fp32_matmul_prec="highest",
scaled_orthogonalize_fn=dummy_interleaved_split_orth_fn,
)
@@ -185,6 +188,35 @@ def test_use_syrk_match_without_syrk(self) -> None:
ref_param.data,
)

+ def test_use_independent_wd(self) -> None:
+     """Test that use_independent_wd properly decouples weight decay from learning rate."""
+     shape = (32, 32)
+     weight_decay = 0.25
+
+     # Test with independent weight decay: with lr=0, weight decay should still be applied
+     # With lr=0, no gradient update occurs, so param should be exactly (1-wd)*param
+     indep_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device="cuda"))
+     indep_param_initial = indep_param.data.clone()
+     indep_param.grad = torch.randint_like(indep_param, -5, 5)
+
+     muon_opt_indep = muon.Muon(
+         [indep_param],
+         lr=0.0,  # Zero learning rate
+         weight_decay=weight_decay,
+         use_independent_wd=True,
+         momentum_beta=0.0,
+     )
+     muon_opt_indep.step()
+
+     # With independent weight decay and lr=0, param should be exactly (1-wd)*param
+     expected_param = (1 - weight_decay) * indep_param_initial
+     torch.testing.assert_close(
+         indep_param.data,
+         expected_param,
+         atol=0,
+         rtol=0,
+     )


if __name__ == "__main__":
absltest.main()
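For contrast with the new test, here is a hypothetical counterpart that is not part of the PR: with the default use_decoupled_wd=True and use_independent_wd=False, the same lr=0 step should leave the parameter unchanged, because the shrinkage shown in step() is lr * wd = 0. The import path for muon is assumed from the file layout shown in this diff, and a CUDA device is assumed as in the tests above.

import torch
from torch import nn

# Import path assumed from the repository layout shown in this PR.
from emerging_optimizers.orthogonalized_optimizers import muon

dec_param = nn.Parameter(torch.randn(32, 32, device="cuda"))
dec_param_initial = dec_param.data.clone()
dec_param.grad = torch.randn_like(dec_param)

muon_opt_dec = muon.Muon(
    [dec_param],
    lr=0.0,                    # zero learning rate, as in the independent-wd test
    weight_decay=0.25,
    use_independent_wd=False,  # weight decay tied to lr, so lr=0 means no shrinkage
    momentum_beta=0.0,
)
muon_opt_dec.step()

# Expectation: the parameter comes back unchanged.
torch.testing.assert_close(dec_param.data, dec_param_initial)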