Commit e39c9b1

cleaned up adaptive orthogonalized optimizer

Signed-off-by: mikail <[email protected]>

1 parent: 9cf84db

File tree

1 file changed: +29 -11 lines

emerging_optimizers/orthogonalized_optimizers/adaptive_orthogonalized_optimizer.py

Lines changed: 29 additions & 11 deletions
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, Literal
+from typing import Callable, Literal
 
 
 # TODO(@boxiangw): remove this once bump to python 3.12
@@ -22,10 +22,13 @@
 from typing_extensions import override
 
 import torch
+from absl import logging
 from torch.optim.optimizer import ParamsT
 
 from emerging_optimizers import mixin as opt_mixin
 from emerging_optimizers import utils
+from emerging_optimizers.orthogonalized_optimizers import muon_utils
+from emerging_optimizers.orthogonalized_optimizers.muon import get_muon_scale_factor
 from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer
 
 
@@ -73,21 +76,37 @@ def __init__(
         eps: float = 1e-8,
         weight_decay_method: opt_mixin.WeightDecayT = "decoupled",
         fp32_matmul_prec: str,
-        scaled_orthogonalize_fn: Callable | None = None,
-        **kwargs: Any,
+        coefficient_type: str = "quintic",
+        num_ns_steps: int = 5,
+        scale_mode: str = "spectral",
+        extra_scale_factor: float = 1.0,
+        use_syrk: bool = False,
     ):
         self.second_moment_method = second_moment_method
 
+        def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor:
+            logging.debug(
+                f"Orthogonalizing grad with {num_ns_steps} steps, {coefficient_type} coefficient, "
+                f"{scale_mode} scale mode, extra_scale_factor={extra_scale_factor}"
+            )
+            orth_grad = muon_utils.newton_schulz(
+                grad,
+                steps=num_ns_steps,
+                coefficient_type=coefficient_type,
+                use_syrk=use_syrk,
+            )
+            scale_factor = get_muon_scale_factor(grad.size(-2), grad.size(-1), mode=scale_mode)
+            return orth_grad * scale_factor * extra_scale_factor
+
         super().__init__(
-            params=params,
-            lr=lr,
-            momentum_beta=momentum_beta,
-            weight_decay=weight_decay,
+            params,
+            lr,
+            momentum_beta,
             use_nesterov=use_nesterov,
+            weight_decay=weight_decay,
             weight_decay_method=weight_decay_method,
             fp32_matmul_prec=fp32_matmul_prec,
             scaled_orthogonalize_fn=scaled_orthogonalize_fn,
-            **kwargs,
         )
 
         for group in self.param_groups:
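
The closure built above delegates the actual orthogonalization to muon_utils.newton_schulz and get_muon_scale_factor, whose internals are not part of this diff. As a rough, self-contained sketch of what a quintic Newton-Schulz orthogonalization of this kind computes (the coefficients, normalization, and function name below are illustrative assumptions, not the library's exact implementation):

import torch

def newton_schulz_sketch(grad: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
    # Widely used quintic coefficients for Muon-style Newton-Schulz iteration.
    a, b, c = 3.4445, -4.7750, 2.0315
    x = grad.float()
    transposed = x.size(-2) > x.size(-1)
    if transposed:  # iterate on the wide orientation
        x = x.mT
    # Normalize so the spectral norm is at most ~1 before iterating.
    x = x / (x.norm(dim=(-2, -1), keepdim=True) + eps)
    for _ in range(steps):
        gram = x @ x.mT
        x = a * x + (b * gram + c * gram @ gram) @ x
    if transposed:
        x = x.mT
    return x.to(grad.dtype)

Note the interface change: instead of accepting a caller-supplied scaled_orthogonalize_fn plus **kwargs, the constructor now exposes explicit knobs (coefficient_type, num_ns_steps, scale_mode, extra_scale_factor, use_syrk) and builds the closure itself.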
@@ -154,7 +173,7 @@ def _apply_second_moment_normalization(
         """
         if self.second_moment_method == "adamuon":
             # AdamMuon: Full elementwise second moment like AdamW
-            # Update second moment with EMA of squared gradient
+            # Update second moment with EMA of squared orthogonalized gradient
             second_moment.lerp_(orth_grad.square(), 1 - beta2)
 
             # AdamW-style division: grad / (sqrt(second_moment) + eps)
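
For the "adamuon" branch shown here, the elementwise normalization reduces to the familiar AdamW-style moving average followed by a division. A minimal sketch (argument names are hypothetical, and any bias correction the full method may apply is omitted):

import torch

def adamuon_normalize_sketch(orth_grad: torch.Tensor, second_moment: torch.Tensor,
                             beta2: float = 0.999, eps: float = 1e-8) -> torch.Tensor:
    # EMA of the squared *orthogonalized* gradient, as in the updated comment above.
    second_moment.lerp_(orth_grad.square(), 1 - beta2)
    # AdamW-style division: grad / (sqrt(second_moment) + eps)
    return orth_grad / (second_moment.sqrt() + eps)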
@@ -224,8 +243,7 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
                 grad = exp_avg
 
                 with utils.fp32_matmul_precision(self.fp32_matmul_prec):
-                    group_kwargs = {k: v for k, v in group.items() if k != "params"}
-                    grad = self.orthogonalize(p, grad, **group_kwargs)
+                    grad = self.scaled_orthogonalize_fn(grad)
 
                 # Apply second moment normalization
                 grad = self._apply_second_moment_normalization(
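
Putting the pieces together, the per-parameter flow in step() now looks roughly like the sketch below. This is a simplified illustration under assumed names and ordering; weight decay, Nesterov momentum, and the optimizer's state handling are omitted:

import torch

def step_flow_sketch(p: torch.Tensor, grad: torch.Tensor, exp_avg: torch.Tensor,
                     second_moment: torch.Tensor, scaled_orthogonalize_fn,
                     lr: float = 1e-3, momentum_beta: float = 0.95,
                     beta2: float = 0.999, eps: float = 1e-8) -> None:
    # Momentum EMA of the raw gradient.
    exp_avg.lerp_(grad, 1 - momentum_beta)
    # Orthogonalize the momentum buffer (Newton-Schulz plus Muon scaling).
    update = scaled_orthogonalize_fn(exp_avg)
    # Elementwise second-moment normalization of the orthogonalized update.
    second_moment.lerp_(update.square(), 1 - beta2)
    update = update / (second_moment.sqrt() + eps)
    # Apply the update.
    p.add_(update, alpha=-lr)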
