Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion emerging_optimizers/orthogonalized_optimizers/muon.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class Muon(OrthogonalizedOptimizer):
Args:
{_args_doc}
coefficient_type: The type of coefficient set to use for the Newton-Schulz iteration. Can be one of
["simple", "quintic", "polar_express"].
["simple", "quintic", "polar_express", "cans"].
num_ns_steps: The number of iteration steps to use in the Newton-Schulz iteration.
scale_mode: The type of scale factor to use for the update. Defaults to "spectral" style scaling.
extra_scale_factor: The additional scale factor to use for the update. Setting it to 0.2 can closely match
Expand Down
16 changes: 14 additions & 2 deletions emerging_optimizers/orthogonalized_optimizers/muon_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

CoeffIterMode = Literal["cycle", "repeat_last"]

NSCoeffT = Literal["simple", "quintic", "polar_express", "aol", "custom"]
NSCoeffT = Literal["simple", "quintic", "polar_express", "cans", "aol", "custom"]

_COEFFICIENT_SETS = {
# Values are rounded to closest representable in single precision.
Expand Down Expand Up @@ -55,6 +55,16 @@
(1.8564, -1.2132, 0.3568),
(1.8750, -1.2500, 0.3750),
],
"cans": [
# CANS from: http://arxiv.org/abs/2506.10935
# CANS iteration (Remez + adaptive interval) based coefficients.
# Source (for generating CANS coefficients): https://github.com/GrishKate/accelerating_orthogonalization/blob/main/polynomials.py
(8.4703, -25.1081, 18.6293),
(4.1828, -3.1087, 0.5806),
(3.9619, -2.9541, 0.5630),
(3.2866, -2.4647, 0.5074),
(2.2737, -1.6447, 0.4162),
],
"aol": [
# from https://github.com/thib-s/flash-newton-schulz/blob/main/newton_schulz_triton.py#L511
(4.0098, -7.0585, 2.4635),
Expand Down Expand Up @@ -136,6 +146,8 @@ def newton_schulz(
- "simple": Default coefficient set.
- "quintic": Quintic iteration with optimized coefficients.
- "polar_express": Polar Express iteration with optimized coefficients.
- "cans": CANS iteration with Remez + adaptive interval coefficients.
- "aol": AOL coefficient set.
- "custom": Custom coefficient sets.

Arguments:
Expand Down Expand Up @@ -179,7 +191,7 @@ def newton_schulz(
else:
raise ValueError(f"Invalid coefficient type: {coefficient_type}")

iter_mode: CoeffIterMode = "cycle" if coefficient_type != "polar_express" else "repeat_last"
iter_mode: CoeffIterMode = "repeat_last" if coefficient_type in ("polar_express", "cans") else "cycle"
coeff_iter = get_coefficient_iterator(steps, coefficient_sets, mode=iter_mode)

ns_step_fn = newton_schulz_step
Expand Down
2 changes: 1 addition & 1 deletion emerging_optimizers/orthogonalized_optimizers/polargrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class PolarGrad(OrthogonalizedOptimizer):
Args:
{_args_doc}
coefficient_type: The type of coefficient set to use for the Newton-Schulz iteration. Can be one of
["simple", "quintic", "polar_express"].
["simple", "quintic", "polar_express", "cans"].
num_ns_steps: The number of iteration steps to use in the Newton-Schulz iteration.
extra_scale_factor: The additional scale factor to use for the update. Setting it to 0.2 can closely match
the update RMS norm of AdamW as suggested by https://arxiv.org/abs/2502.16982.
Expand Down
2 changes: 1 addition & 1 deletion emerging_optimizers/orthogonalized_optimizers/scion.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class Scion(OrthogonalizedOptimizer):
momentum: The momentum used by the internal SGD.
fp32_matmul_prec: Precision of the matmul operations in optimizer states GEMM operations.
coefficient_type: The type of coefficient set to use for the Newton-Schulz iteration. Can be one of
["simple", "quintic", "polar_express"].
["simple", "quintic", "polar_express", "cans"].
num_ns_steps: The number of iteration steps to use in the Newton-Schulz iteration.
spectral_radius: The spectral radius to use for the update, we are scaling the LMO by this spectral radius.
"""
Expand Down
31 changes: 31 additions & 0 deletions tests/test_muon_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,37 @@ def test_get_polar_express_9steps_close_to_reference(self, dim1, dim2):
out_ref = newton_schulz_ref(x, coefficient_sets=coeff)
torch.testing.assert_close(out_pe9, out_ref, atol=2e-6, rtol=1e-7)

@parameterized.parameters(
(512, 512),
(512, 256),
(256, 512),
)
def test_cans_close_to_reference(self, dim1, dim2):
x = torch.randn(dim1, dim2, device=self.device, dtype=torch.float32)
out_cans_test = muon_utils.newton_schulz(x, steps=5, coefficient_type="cans")
out_cans_ref = newton_schulz_ref(x, coefficient_sets=muon_utils._COEFFICIENT_SETS["cans"])

torch.testing.assert_close(
out_cans_test,
out_cans_ref,
atol=1e-5,
rtol=1e-7,
)

@parameterized.parameters(
(511, 513),
(511, 257),
(257, 513),
)
def test_get_cans_9steps_close_to_reference(self, dim1, dim2):
x = torch.randn(dim1, dim2, device=self.device, dtype=torch.float32)
out_cans9 = muon_utils.newton_schulz(x, steps=9, coefficient_type="cans")
coeff = deepcopy(muon_utils._COEFFICIENT_SETS["cans"])
# CANS uses repeat_last, so repeat the last tuple for remaining steps.
coeff.extend([coeff[-1]] * 4)
out_ref = newton_schulz_ref(x, coefficient_sets=coeff)
torch.testing.assert_close(out_cans9, out_ref, atol=2e-6, rtol=1e-7)


@absltest.skipIf(
_SM_VERSION not in ((8, 0), (9, 0), (10, 0), (10, 3)),
Expand Down