Merged
2 changes: 2 additions & 0 deletions emerging_optimizers/orthogonalized_optimizers/muon.py
@@ -67,6 +67,7 @@ def __init__(
use_nesterov: bool = False,
weight_decay: float = 0.01,
use_decoupled_weight_decay: bool = True,
use_independent_weight_decay: bool = False,
Contributor

Change: The variable names are getting too long. Let's make them "use_decoupled_wd" and "use_independent_wd".

I don't think there are ambiguities around wd being short for weight decay.

Contributor Author

Changed the names inside the optimizer to use_decoupled_wd and use_independent_wd. The __init__ call still uses the longer names, so as not to break the Megatron wrappers.

Contributor

Megatron can update; there will be a lot of updates on dev anyway.
Nothing will break, since the dependency was pinned to a commit, not the head of the branch.
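A minimal sketch of the naming split described in the reply above, under assumed class names; this is not the code in this commit. The public constructor keeps the long keyword names so existing callers such as the Megatron wrappers keep working, while the internals switch to the shorter "wd" names.

```python
# Hypothetical sketch only: class names and structure are illustrative, not this repo's code.


class _OrthogonalizedBase:
    def __init__(self, use_decoupled_wd: bool, use_independent_wd: bool) -> None:
        # The optimizer internals use the short names as default/group keys.
        self.defaults = {
            "use_decoupled_wd": use_decoupled_wd,
            "use_independent_wd": use_independent_wd,
        }


class MuonLike(_OrthogonalizedBase):
    def __init__(
        self,
        use_decoupled_weight_decay: bool = True,
        use_independent_weight_decay: bool = False,
    ) -> None:
        # The public keyword names stay long for backward compatibility and are
        # forwarded to the short internal names.
        super().__init__(
            use_decoupled_wd=use_decoupled_weight_decay,
            use_independent_wd=use_independent_weight_decay,
        )
```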

fp32_matmul_prec: str = "medium",
coefficient_type: str = "quintic",
num_ns_steps: int = 5,
@@ -107,6 +108,7 @@ def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor:
use_nesterov,
weight_decay,
use_decoupled_weight_decay,
use_independent_weight_decay,
fp32_matmul_prec,
scaled_orthogonalize_fn,
)
@@ -36,6 +36,8 @@
weight_decay: The weight decay used by the optimizer, default to be decoupled weight decay.
See Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101
use_decoupled_weight_decay: Whether to use decoupled weight decay, default to be True.
use_independent_weight_decay: Whether to use independent weight decay (https://arxiv.org/abs/2510.19093),
default to be False.
fp32_matmul_prec: Precision of the matmul operations in optimizer states GEMM operations.
"""

@@ -100,6 +102,7 @@ def __init__(
use_nesterov: bool,
weight_decay: float,
use_decoupled_weight_decay: bool,
use_independent_weight_decay: bool,
fp32_matmul_prec: str,
scaled_orthogonalize_fn: Callable | None = None,
**kwargs: Any,
@@ -115,6 +118,7 @@ def __init__(
use_nesterov=use_nesterov,
weight_decay=weight_decay,
use_decoupled_weight_decay=use_decoupled_weight_decay,
use_independent_weight_decay=use_independent_weight_decay,
**kwargs,
)

@@ -154,7 +158,12 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
if group["weight_decay"] > 0.0:
if group["use_decoupled_weight_decay"]:
# Apply decoupled weight decay
p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
if group["use_independent_weight_decay"]:
# use independent weight decay
weight_decay_scale = group["weight_decay"]
else:
weight_decay_scale = group["weight_decay"] * group["lr"]
p.add_(p, alpha=(-weight_decay_scale))
else:
# add l2 regularization before preconditioning (i.e. adding a squared loss term)
grad += group["weight_decay"] * p
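For reference, a self-contained sketch of the three weight-decay behaviors this branch distinguishes, using a plain SGD step and made-up tensor names rather than the optimizer's actual API: classic L2 regularization folds the penalty into the gradient before preconditioning, decoupled decay shrinks the weights by lr * weight_decay, and the new independent decay (arXiv:2510.19093) shrinks them by weight_decay alone, so the decay strength does not follow the learning-rate schedule.

```python
import torch


def sgd_step_with_weight_decay(
    p: torch.Tensor,
    grad: torch.Tensor,
    lr: float,
    weight_decay: float,
    use_decoupled_wd: bool,
    use_independent_wd: bool,
) -> torch.Tensor:
    """Illustrative helper mirroring the branching above; not the optimizer's real step."""
    if weight_decay > 0.0:
        if use_decoupled_wd:
            # Decoupled decay acts on the weights directly; the independent
            # variant drops the lr factor so the decay is constant per step.
            scale = weight_decay if use_independent_wd else lr * weight_decay
            p = p * (1.0 - scale)
        else:
            # L2 regularization: add the penalty to the gradient before any
            # preconditioning/orthogonalization would be applied.
            grad = grad + weight_decay * p
    return p - lr * grad


# With lr=1e-3 and weight_decay=0.5, decoupled decay multiplies the weights by
# (1 - 5e-4) per step, while independent decay multiplies them by 0.5.
w, g = torch.randn(4, 4), torch.randn(4, 4)
w_decoupled = sgd_step_with_weight_decay(w, g, 1e-3, 0.5, True, False)
w_independent = sgd_step_with_weight_decay(w, g, 1e-3, 0.5, True, True)
```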
3 changes: 3 additions & 0 deletions tests/test_orthogonalized_optimizer.py
@@ -43,6 +43,7 @@ def test_orthogonalized_optimizer_core_matches_sgd(self, shape) -> None:
use_nesterov=False,
weight_decay=0.5,
use_decoupled_weight_decay=True,
use_independent_weight_decay=False,
fp32_matmul_prec="highest",
)

@@ -84,6 +85,7 @@ def test_orthogonalized_optimizer_core_matches_sgd_with_momentum(self, shape) ->
use_nesterov=False,
weight_decay=0.0,
use_decoupled_weight_decay=False,
use_independent_weight_decay=False,
fp32_matmul_prec="highest",
)

@@ -134,6 +136,7 @@ def dummy_interleaved_split_orth_fn(x: torch.Tensor) -> torch.Tensor:
use_nesterov=False,
weight_decay=0.0,
use_decoupled_weight_decay=False,
use_independent_weight_decay=False,
fp32_matmul_prec="highest",
scaled_orthogonalize_fn=dummy_interleaved_split_orth_fn,
)
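A hypothetical usage sketch of the new flag: the class name Muon, the import path, and the params/lr arguments are assumed from the file name and from standard PyTorch optimizer conventions rather than shown in this diff.

```python
import torch

# Assumed import path and class name; only muon.py's file name appears in this PR.
from emerging_optimizers.orthogonalized_optimizers.muon import Muon

model = torch.nn.Linear(32, 32)
optimizer = Muon(
    [model.weight],                      # assumed: only the 2D weight, as Muon-style optimizers orthogonalize matrices
    lr=1e-3,                             # assumed: standard lr argument
    weight_decay=0.01,
    use_decoupled_weight_decay=True,
    use_independent_weight_decay=True,   # new flag added in this PR
)
```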