
Commit 27aa01c

comment update
Signed-off-by: Hao Wu <[email protected]>
1 parent faee14c commit 27aa01c

File tree

2 files changed: +15 −13 lines changed


emerging_optimizers/orthogonalized_optimizers/muon.py

Lines changed: 2 additions & 1 deletion
@@ -93,7 +93,8 @@ def __init__(
 
         def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor:
             logging.debug(
-                f"Orthogonalizing grad with {num_ns_steps} steps, {coefficient_type} coefficient, {scale_mode} scale mode, extra_scale_factor={extra_scale_factor}"
+                f"Orthogonalizing grad with {num_ns_steps} steps, {coefficient_type} coefficient, "
+                f"{scale_mode} scale mode, extra_scale_factor={extra_scale_factor}"
             )
             orth_grad = newton_schulz(grad, steps=num_ns_steps, coefficient_type=coefficient_type, use_syrk=use_syrk)
             scale_factor = get_muon_scale_factor(grad.size(-2), grad.size(-1), mode=scale_mode)
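
The two-line form is purely cosmetic: Python concatenates adjacent f-string literals at compile time, so the message passed to logging.debug is identical to the single-line version. A quick standalone check, using placeholder values for the variables the closure captures:

num_ns_steps, coefficient_type = 5, "quintic"      # placeholder values, not from the commit
scale_mode, extra_scale_factor = "spectral", 1.0   # placeholder values, not from the commit

one_line = f"Orthogonalizing grad with {num_ns_steps} steps, {coefficient_type} coefficient, {scale_mode} scale mode, extra_scale_factor={extra_scale_factor}"
two_line = (
    f"Orthogonalizing grad with {num_ns_steps} steps, {coefficient_type} coefficient, "
    f"{scale_mode} scale mode, extra_scale_factor={extra_scale_factor}"
)
assert one_line == two_line  # adjacent literals are joined at compile time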

emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py

Lines changed: 13 additions & 12 deletions
@@ -36,12 +36,6 @@
     weight_decay: The weight decay used by the optimizer, default to be decoupled weight decay.
         See Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101
     use_decoupled_weight_decay: Whether to use decoupled weight decay, default to be True.
-    split_fused: Whether to split fused parameters (QKV, GQA, etc.) for preconditioning, default to be False.
-    is_fused_fn: Function to check if a parameter is fused parameters (QKV, GQA, etc.).
-        If multiple types of parameters are fused, the function should return True for all of which needs to be
-        split for preconditioning.
-    split_fn: Function to split the fused parameters (QKV, GQA, etc.) into a list of parameters.
-        It should support all the types of parameters that is_fused_fn returns True for.
     fp32_matmul_prec: Precision of the matmul operations in optimizer states GEMM operations.
 """

@@ -70,8 +64,7 @@ class OrthogonalizedOptimizer(optim.Optimizer):
 
     Args:
         {_args_doc}
-        orthogonalize_fn: Function to orthogonalize the updates.
-        scale_factor_fn: Function to compute the scale factor for the update.
+        scaled_orthogonalize_fn: Function to orthogonalize and scale the updates.
        **kwargs: Arguments passed through to the base optimizer.
 
    Note:
@@ -155,21 +148,29 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
                     grad = exp_avg
 
                 with utils.fp32_matmul_precision(self.fp32_matmul_prec):
-                    grad = self.orthogonalize(p, grad)
+                    group_kwargs = {k: v for k, v in group.items() if k != "params"}
+                    grad = self.orthogonalize(p, grad, **group_kwargs)
 
                 # perform weight update
                 # scale is applied to have update RMS == 1
                 p.add_(grad, alpha=-group["lr"])
 
         return loss
 
-    def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
+    def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor, **kwargs: Any) -> torch.Tensor:
         """Orthogonalize the momentum.
 
+        The default orthogonalize function calls the scaled_orthogonalize_fn with the gradient. Subclass can
+        override this function to implement different orthogonalization logic as well as split fused parameters.
+        For example, a scaled_orthogonalize_fn function can get attributes from p or from kwargs to determine if
+        the parameter is a fused parameter and should be split for preconditioning.
+
         Args:
-            p: The parameter tensor. i is necessary to pass param tensor in addition to momentum because a lot of
-                information is only available in the param tensor, attributes for example.
+            p: The parameter tensor. It is necessary to pass param tensor in addition to momentum because a lot of
+                information is only available in the param tensor, attributes for example. Although not used in
+                this default orthogonalize function.
             grad: The momentum tensor.
+            **kwargs: keyword arguments of the param_group that p was belonged to.
 
         Returns:
             The orthogonalized gradient tensor.
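
Together, the group_kwargs forwarding in step() and the new **kwargs parameter on orthogonalize() give subclasses access to per-group options, which is what the updated docstring suggests for splitting fused parameters. A minimal sketch of such a subclass, assuming a hypothetical num_fused convention (an attribute on the param or a param_group option, neither of which is part of this commit) and omitting constructor arguments:

import torch

from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer


class SplitFusedOptimizer(OrthogonalizedOptimizer):
    """Hypothetical subclass: orthogonalize fused params (QKV, GQA, etc.) block by block."""

    def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor, **kwargs) -> torch.Tensor:
        # `num_fused` is an assumed convention: either tagged onto the param as an
        # attribute or set on the param_group, which step() now forwards as kwargs.
        num_fused = getattr(p, "num_fused", kwargs.get("num_fused", 1))
        if num_fused > 1:
            # Orthogonalize each fused block separately so each block gets a
            # scale factor matching its own shape, then re-concatenate.
            blocks = []
            for chunk in grad.chunk(num_fused, dim=0):
                blocks.append(super().orthogonalize(p, chunk, **kwargs))
            return torch.cat(blocks, dim=0)
        return super().orthogonalize(p, grad, **kwargs)

Because step() builds group_kwargs as {k: v for k, v in group.items() if k != "params"}, any option set on a param group (lr, momentum, or a custom flag like the num_fused above) reaches the override without further plumbing.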
