Add SSO (Spectral Sphere Optimizer) and MuonSphere optimizers #2466
WhenWen wants to merge 1 commit into marin-community:main
This commit adds two new optimizers for training language models:

1. SSO (Spectral Sphere Optimizer): Full spectral sphere optimization with lambda solver
   - Retracts 2D weight matrices to the spectral sphere with radius R = radius_scaler * sqrt(d_out/d_in)
   - Applies the msign update (matrix sign function via Newton-Schulz iteration)
   - Solves for lambda to enforce the tangent constraint
2. MuonSphere: Simplified version with lambda=0
   - Same as SSO but without the lambda solver, for faster computation

Key features:
- Support for scan layers (automatically vmaps over the layer dimension)
- Polar Express Newton-Schulz coefficients for msign computation
- Power iteration for top singular value/vector estimation
- Compatible with the haliax partitioning system
- Includes an example experiment for a radius_scaler sweep

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
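For readers who want a concrete picture of the retraction step, here is a minimal, dependency-free sketch. It is not the code in this PR (which is written against JAX/haliax); the function names, the `iters` and `seed` parameters, and the convention that rows are `d_out` and columns are `d_in` are all assumptions made for illustration. It estimates the top singular value by power iteration and rescales the matrix so that value equals R = radius_scaler * sqrt(d_out/d_in).

```python
import math
import random


def matvec(m, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def norm(v):
    """Euclidean norm of a vector."""
    return math.sqrt(sum(x * x for x in v))


def top_singular_value(w, iters=50, seed=0):
    """Estimate the largest singular value of w by power iteration on w^T w.

    Assumes w is a non-zero 2D matrix given as a list of rows.
    """
    rng = random.Random(seed)
    wt = [list(col) for col in zip(*w)]
    v = [rng.gauss(0.0, 1.0) for _ in range(len(w[0]))]
    n = norm(v)
    v = [x / n for x in v]
    for _ in range(iters):
        v = matvec(wt, matvec(w, v))  # one step of (w^T w) v
        n = norm(v)
        v = [x / n for x in v]
    # For a unit right singular vector v, ||w v|| is the singular value.
    return norm(matvec(w, v))


def retract_to_spectral_sphere(w, radius_scaler=1.0):
    """Rescale w so its spectral norm equals radius_scaler * sqrt(d_out / d_in).

    Assumption for this sketch: rows index d_out, columns index d_in.
    """
    d_out, d_in = len(w), len(w[0])
    target = radius_scaler * math.sqrt(d_out / d_in)
    scale = target / top_singular_value(w)
    return [[scale * x for x in row] for row in w]
```

Usage: `retract_to_spectral_sphere(w, radius_scaler=1.0)` returns a rescaled copy of `w`. A uniform rescale is the simplest way to land on the sphere; whether the PR retracts in exactly this way is not confirmed by the text above.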
A quick test run here: https://api.wandb.ai/links/marin-community/emank8v9
def msign_newton_schulz(
@WhenWen why not reuse/replace zeropower_via_newtonschulz5 from levanter.optim.muon?
Oh my, I just remembered we were using the old coefficients this entire time 👀
Yeah we should likely update Muon to allow selecting a better coefficient. Will write a PR
But for SSO I have been using polar express with step 8 lol
We should probably move this function to lib/levanter/src/levanter/optim/utils.py. Wdyt?
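To make the "selectable coefficients" idea in this thread concrete, here is a rough, dependency-free sketch of a shared helper. The signature is hypothetical and is not the one in this PR or in `levanter.optim.muon`. It takes one `(a, b, c)` triple per step, so a fixed set can be repeated or a per-step schedule (as Polar Express uses) can be passed in. The Polar Express values are not reproduced here. `MUON_QUINTIC` holds the coefficients published with the original Muon implementation; that these are the "old coefficients" mentioned above is an assumption.

```python
import math

# Classic cubic Newton-Schulz: X <- 1.5 X - 0.5 (X X^T) X, converges to msign(X).
CLASSIC_CUBIC = (1.5, -0.5, 0.0)
# Quintic coefficients from the original Muon implementation (assumed to be
# the "old coefficients" discussed in this thread). They trade exact
# convergence for speed, so singular values only end up close to 1.
MUON_QUINTIC = (3.4445, -4.7750, 2.0315)


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def msign_newton_schulz(g, coeffs):
    """Approximate msign(g) = U V^T with one Newton-Schulz step per triple.

    coeffs is a sequence of (a, b, c); each step applies
    X <- a X + (b A + c A^2) X with A = X X^T.
    Assumes g is a non-zero 2D matrix given as a list of rows.
    """
    # Normalise by the Frobenius norm so every singular value is at most 1.
    fro = math.sqrt(sum(v * v for row in g for v in row))
    x = [[v / fro for v in row] for row in g]
    for a, b, c in coeffs:
        gram = matmul(x, [list(col) for col in zip(*x)])  # A = X X^T
        gram2 = matmul(gram, gram)  # A^2
        poly = [[b * p + c * q for p, q in zip(r1, r2)]
                for r1, r2 in zip(gram, gram2)]
        px = matmul(poly, x)
        x = [[a * v + w for v, w in zip(r1, r2)] for r1, r2 in zip(x, px)]
    return x
```

Usage: `msign_newton_schulz(g, [CLASSIC_CUBIC] * 30)` runs 30 cubic steps, and `msign_newton_schulz(g, [MUON_QUINTIC] * 5)` runs five quintic steps. Passing the schedule as data would let Muon and SSO share one function while choosing different coefficients and step counts.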
This pull request has been inactive for 23 days and is marked as stale.
This commit adds two new optimizers for training language models from https://arxiv.org/abs/2601.08393:
SSO (Spectral Sphere Optimizer): Full spectral sphere optimization with lambda solver
MuonSphere: Simplified version with lambda=0
Key features: