Skip to content

Commit da4e311

Browse files
committed
add a liberal for NS coefficient
Signed-off-by: Hao Wu <[email protected]>
1 parent 53bc522 commit da4e311

File tree

4 files changed

+14
-7
lines changed

4 files changed

+14
-7
lines changed

emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from emerging_optimizers import mixin as opt_mixin
2828
from emerging_optimizers import utils
2929
from emerging_optimizers.orthogonalized_optimizers import muon
30+
from emerging_optimizers.orthogonalized_optimizers.muon_utils import NSCoeffT
3031
from emerging_optimizers.utils import FP32MatmulPrecT
3132

3233

@@ -67,7 +68,7 @@ def __init__(
6768
use_nesterov: bool,
6869
weight_decay_method: opt_mixin.WeightDecayT = "decoupled",
6970
fp32_matmul_prec: FP32MatmulPrecT,
70-
coefficient_type: str = "quintic",
71+
coefficient_type: NSCoeffT = "quintic",
7172
num_ns_steps: int = 5,
7273
scale_mode: muon.MuonScaleT = "spectral",
7374
extra_scale_factor: float = 1.0,

emerging_optimizers/orthogonalized_optimizers/muon.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from emerging_optimizers import triton_kernels
2323
from emerging_optimizers.mixin import WeightDecayT
2424
from emerging_optimizers.orthogonalized_optimizers import muon_utils
25+
from emerging_optimizers.orthogonalized_optimizers.muon_utils import NSCoeffT
2526
from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer, _args_doc
2627
from emerging_optimizers.utils import FP32MatmulPrecT
2728

@@ -77,7 +78,7 @@ def __init__(
7778
use_nesterov: bool = False,
7879
weight_decay_method: WeightDecayT = "decoupled",
7980
fp32_matmul_prec: FP32MatmulPrecT = "medium",
80-
coefficient_type: str = "quintic",
81+
coefficient_type: NSCoeffT = "quintic",
8182
num_ns_steps: int = 5,
8283
scale_mode: MuonScaleT = "spectral",
8384
extra_scale_factor: float = 1.0,

emerging_optimizers/orthogonalized_optimizers/muon_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222

2323
__all__ = ["newton_schulz", "newton_schulz_tp"]
2424

25+
NSCoeffT = Literal["simple", "quintic", "polar_express", "aol", "custom"]
26+
2527
_COEFFICIENT_SETS = {
2628
"simple": [
2729
(3.4445, -4.7750, 2.0315),
@@ -67,7 +69,7 @@ def distributed_normalize_p2(x: torch.Tensor, eps: float, group: torch.distribut
6769
def newton_schulz(
6870
x: torch.Tensor,
6971
steps: int,
70-
coefficient_type: str = "quintic",
72+
coefficient_type: NSCoeffT = "quintic",
7173
custom_coefficient_sets: list[tuple[float, float, float]] | None = None,
7274
eps: float = 1e-7,
7375
transpose: bool | None = None,
@@ -164,7 +166,7 @@ def newton_schulz(
164166
def newton_schulz_tp(
165167
x: torch.Tensor,
166168
steps: int,
167-
coefficient_type: str,
169+
coefficient_type: NSCoeffT,
168170
tp_group: torch.distributed.ProcessGroup,
169171
partition_dim: int | None = None,
170172
mode: Literal["duplicated", "distributed"] = "duplicated",

emerging_optimizers/orthogonalized_optimizers/scion.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@
1717
from absl import logging
1818
from torch.optim.optimizer import ParamsT
1919

20+
from emerging_optimizers.orthogonalized_optimizers import muon_utils
2021
from emerging_optimizers.orthogonalized_optimizers.muon import get_muon_scale_factor
21-
from emerging_optimizers.orthogonalized_optimizers.muon_utils import newton_schulz
22+
from emerging_optimizers.orthogonalized_optimizers.muon_utils import NSCoeffT
2223
from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer
2324
from emerging_optimizers.utils import FP32MatmulPrecT
2425

@@ -63,7 +64,7 @@ def __init__(
6364
momentum_beta: float = 0.95,
6465
*,
6566
fp32_matmul_prec: FP32MatmulPrecT = "medium",
66-
coefficient_type: str = "quintic",
67+
coefficient_type: NSCoeffT = "quintic",
6768
num_ns_steps: int = 5,
6869
spectral_radius: float = 1.0,
6970
) -> None:
@@ -84,7 +85,9 @@ def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor:
8485
logging.debug(
8586
f"Orthogonalizing grad with {num_ns_steps} steps, {coefficient_type} coefficient, spectral_radius={spectral_radius}"
8687
)
87-
orth_grad = newton_schulz(grad, steps=num_ns_steps, coefficient_type=coefficient_type, use_syrk=False)
88+
orth_grad = muon_utils.newton_schulz(
89+
grad, steps=num_ns_steps, coefficient_type=coefficient_type, use_syrk=False
90+
)
8891
width_factor = get_muon_scale_factor(grad.size(-2), grad.size(-1), mode="unit_rms_norm")
8992
return orth_grad * width_factor * spectral_radius
9093

0 commit comments

Comments
 (0)