feat(pt): add AdaMuon optimizer #5130
Conversation
Pull request overview
This PR adds the AdaMuon optimizer to the DeePMD-kit PyTorch backend, which combines Newton-Schulz orthogonalization with adaptive second-moment normalization for improved training stability. The optimizer applies different update rules based on parameter dimensionality: AdaMuon for 2D+ weight matrices and standard Adam for 1D parameters (biases, layer norms).
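For orientation, a minimal sketch of the dimensionality-based routing described above (illustrative only; the helper name and structure are assumptions, not code from this PR):

```python
import torch


def split_params_by_ndim(params):
    """Illustrative routing: >=2D weight matrices go to the AdaMuon path,
    1D tensors (biases, layer norms) go to the plain Adam path."""
    muon_params, adam_params = [], []
    for p in params:
        if p.grad is None:
            continue
        (muon_params if p.ndim >= 2 else adam_params).append(p)
    return muon_params, adam_params
```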
Key changes:
- Implementation of AdaMuonOptimizer with sign-stabilized orthogonal direction updates and per-element normalization
- Comprehensive test suite covering basic functionality, state management, bucketing, learning rate adjustment, and weight decay
- Integration with the training pipeline including configuration arguments and optimizer initialization
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| `deepmd/pt/optimizer/adamuon.py` | New AdaMuonOptimizer implementation with Newton-Schulz orthogonalization and adaptive normalization |
| `source/tests/pt/test_adamuon.py` | Comprehensive test suite covering optimizer functionality, state management, and various configurations |
| `deepmd/utils/argcheck.py` | Added configuration arguments for AdaMuon optimizer parameters (momentum, betas, weight_decay, lr_adjust) |
| `deepmd/pt/train/training.py` | Integrated AdaMuon optimizer into training pipeline with parameter extraction and scheduler setup |
| `deepmd/pt/optimizer/__init__.py` | Added AdaMuonOptimizer to module exports |
📝 Walkthrough
Adds a new AdaMuonOptimizer and helpers (Newton–Schulz orthogonalization + momentum prep), integrates it into training arg parsing and the Trainer optimizer flow, and introduces tests validating behavior, state handling, bucketing, lr_adjust modes, and weight decay.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Trainer
    participant AdaMuonOptimizer
    participant Segregator as Param<br/>Segregation
    participant AdamPath as Adam<br/>(1D)
    participant MuonPath as AdaMuon<br/>(≥2D)
    participant Buckets
    participant NS as Newton-Schulz
    participant Scale as Norm & Scale
    participant Params as Parameter<br/>Update

    Trainer->>AdaMuonOptimizer: step()
    AdaMuonOptimizer->>Segregator: separate params by ndim
    Segregator-->>AdamPath: 1D params
    Segregator-->>MuonPath: ≥2D params
    AdamPath->>AdamPath: update exp_avg/exp_avg_sq (FP32) and compute step
    AdamPath->>Params: apply scaled Adam updates
    MuonPath->>MuonPath: apply weight decay, update momentum buffers
    MuonPath->>Buckets: reshape to (M,N) and group by shape/device
    Buckets->>NS: stack sign matrices, run Newton–Schulz
    NS-->>Scale: return orthogonal directions (bf16)
    Scale->>Scale: update EMA v_buffer, compute per-element norm, apply RMS scaling and lr_adjust
    Scale->>Params: apply orthogonalized, scaled updates (-lr × update)
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
deepmd/pt/train/training.py (1)
158-175: Pass the collected `adam_eps` and `nesterov` parameters to `AdaMuonOptimizer`, or remove them from `get_opt_param()` if they are meant to stay fixed.

`get_opt_param()` collects `adam_eps` (default `1e-7`) and `nesterov` (default `True`) into `opt_param`, but the AdaMuon branch (lines 709–720) does not pass these to the `AdaMuonOptimizer` constructor. They are functional parameters inside `AdaMuonOptimizer` (used on line 334 for the 1D parameter denominator and line 141 for the momentum update logic), but they always use the optimizer's constructor defaults instead of the config values. Either thread them through explicitly:

```python
adam_eps=float(self.opt_param.get("adam_eps", 1e-8)),
nesterov=bool(self.opt_param.get("nesterov", True)),
```

or remove them from `get_opt_param()` to avoid a misleading API surface.

The shared `weight_decay` default of `0.001` is consistent and correct.
🧹 Nitpick comments (7)
deepmd/utils/argcheck.py (1)
3373-3427: Align the AdaMuon training schema with the actual optimizer parameters.

The new `AdaMuon` variant under `opt_type` exposes `muon_momentum`, `adam_beta1/2`, `weight_decay`, `lr_adjust`, and `lr_adjust_coeff`, which matches most of what `Trainer` passes into `AdaMuonOptimizer`. However:

- `get_opt_param()` in `deepmd/pt/train/training.py` also populates `adam_eps` and `nesterov` in `opt_param`, but those values are never forwarded into `AdaMuonOptimizer`'s constructor.
- There is currently no public schema here for `adam_eps` or `nesterov`, so they are effectively non-configurable and the `opt_param` entries are dead.

Consider either:

- Exposing `adam_eps` and `nesterov` here and passing them through when constructing `AdaMuonOptimizer`, or
- Dropping them from `opt_param` if you intend to keep them fixed at the optimizer defaults.

This keeps the configuration surface and runtime behavior in sync.
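For reference, a hypothetical training-config fragment exercising only the keys listed above (key names come from this schema, values are the documented defaults; `adam_eps`/`nesterov` are deliberately absent, which is the gap this comment points out):

```python
# Hypothetical AdaMuon block under the "training" section (illustrative values).
adamuon_training_fragment = {
    "opt_type": "AdaMuon",
    "muon_momentum": 0.95,
    "adam_beta1": 0.9,
    "adam_beta2": 0.95,
    "weight_decay": 0.001,
    "lr_adjust": 10.0,
    "lr_adjust_coeff": 0.2,
    # "adam_eps" / "nesterov": not exposed by the current schema.
}
```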
deepmd/pt/train/training.py (1)
798-804: AdaMuon integrated into the main SGD path; multi-task still not supported.

Including `"AdaMuon"` in the main `if self.opt_type in ["Adam", "AdamW", "AdaMuon"]:` block correctly reuses the Adam-style step and LR scheduling logic for the new optimizer.

Note that in multi-task mode `self.opt_type` becomes a dict, and this branch (as indicated by the nearby TODO) still doesn't support per-task optimizers. If AdaMuon is not meant to be used with multi-task training yet, that's fine; otherwise you may want an explicit check/error when `multi_task=True` and `opt_type == "AdaMuon"` is requested.

deepmd/pt/optimizer/adamuon.py (3)
40-110: Newton–Schulz helper looks sound; consider future configurability.

The `zeropower_via_newtonschulz5` implementation matches the described quintic Newton–Schulz scheme (bf16 compute, Frobenius normalization, tall-matrix transpose, fused GEMMs), and the guards on `ndim` and `steps` are reasonable.

If you later need more flexibility, two low-impact extensions would be:

- Allowing an optional `dtype`/`use_bfloat16` flag so callers can keep everything in FP32 on hardware where bf16 GEMMs are slow or unavailable.
- Relaxing the `steps >= 100` hard check to a warning or configurable cap.

Neither is blocking; the current implementation is fine as-is.
112-148: Momentum preparation helper: clarify Nesterov behavior.

`_prepare_muon_momentum` cleanly encapsulates the momentum update and reshape logic, but the Nesterov path:

```python
momentum_buffer.lerp_(grad, 1 - beta)
update = grad.lerp(momentum_buffer, beta) if nesterov else momentum_buffer
```

does not match the standard `v = β v + g; u = g + β v` Nesterov formulation. If this is intentionally following the AdaMuon reference implementation, a short comment would help avoid future "fixes" that change behavior. If the goal was standard Nesterov, consider switching to the canonical update.

Also, the `original_shape`/reshape handling for >2D params is clear and works well.
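To make the difference concrete, here is a small standalone sketch contrasting the lerp-based update quoted above with the canonical Nesterov form (toy tensors only, not the optimizer's actual state handling):

```python
import torch

g = torch.randn(4, 4)       # current gradient
buf = torch.zeros_like(g)   # momentum buffer
beta = 0.95

# Lerp-based variant (as quoted above):
#   buf <- beta * buf + (1 - beta) * g
#   update = (1 - beta) * g + beta * buf
buf_lerp = buf.lerp(g, 1 - beta)
update_lerp = g.lerp(buf_lerp, beta)

# Canonical Nesterov momentum:
#   v <- beta * v + g
#   update = g + beta * v
v = beta * buf + g
update_nesterov = g + beta * v

print((update_lerp - update_nesterov).abs().max())  # differs already on the first step
```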
256-407: AdaMuonOptimizer step: solid design with a small style nit.

The overall `step()` logic is well-structured:

- Clean separation of >=2D (AdaMuon) vs 1D (FP32 Adam) parameters.
- Decoupled multiplicative weight decay only on >=2D params, as documented.
- Per-parameter Adam state with bias correction via `beta1_pow`/`beta2_pow`.
- Efficient bucketing by `(rows, cols, device)` plus batched Newton–Schulz on stacked sign matrices.
- Per-element `v_buffer` EMA in FP32, followed by RMS-aligned scaling and shape-dependent `adj_scale`.

Two minor points:

In the bucket loop, the `update` variable from the tuple is never read:

```python
for i, (p, update, orig_shape) in enumerate(bucket):
    state = self.state[p]
    orth_vec = orth_stacked[i].flatten().float()
    ...
```

To satisfy linters and future readers, you can rename it to `_update`:

```python
for i, (p, _update, orig_shape) in enumerate(bucket):
    ...
```

If you later want users to tune Newton–Schulz steps, `ns_steps` is already plumbed through the param group; exposing it via config (as you did with `lr_adjust`) would be straightforward.

Functionally this all looks correct and well tested given the dedicated `test_adamuon.py` suite.
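For readers unfamiliar with the AdaMuon scaling summarized in the bullet list above, a condensed sketch of the per-element normalization plus RMS-aligned rescaling (function and variable names are illustrative, not the PR's exact code):

```python
import math

import torch


def adamuon_scale(orth_vec, v_buffer, lr, shape, momentum=0.95, coeff=0.2, eps=1e-8):
    """Illustrative AdaMuon scaling of an orthogonalized, flattened direction."""
    m, n = shape
    # Per-element second-moment EMA of the orthogonalized direction (kept in FP32).
    v_buffer.mul_(momentum).add_(orth_vec.square(), alpha=1 - momentum)
    normed = orth_vec / (v_buffer.sqrt() + eps)
    # RMS-aligned global scaling: target norm is coeff * sqrt(m * n).
    normed = normed / (normed.norm() + eps) * coeff * math.sqrt(m * n)
    return -lr * normed
```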
source/tests/pt/test_adamuon.py (2)
9-14: Torch thread monkeypatch in tests: keep, but simplify the lambda signatures.

The global monkeypatch of `torch.set_num_interop_threads`/`torch.set_num_threads` to no-ops is understandable to avoid thread reconfiguration warnings in the test environment.

To quiet linters and make the intent a bit clearer, you could drop the unused lambda parameters:

```python
if torch_set_num_interop_threads is not None:
    torch.set_num_interop_threads = lambda *_, **__: None  # type: ignore[assignment]
if torch_set_num_threads is not None:
    torch.set_num_threads = lambda *_, **__: None  # type: ignore[assignment]
```

Functionally equivalent, just a small style tweak.

100-106: Use `strict=True` in `zip` in tests for extra safety.

Both here:

```python
for i, (p, init_p) in enumerate(zip(model.parameters(), initial_params)):
    ...
```

and later when zipping `params1`/`params2` for the state dict comparison, adding `strict=True` to `zip` would make mismatched lengths fail loudly instead of silently truncating:

```python
for i, (p, init_p) in enumerate(zip(model.parameters(), initial_params, strict=True)):
    ...

for p1, p2 in zip(params1, params2, strict=True):
    ...
```

This is purely a robustness/readability improvement; the current tests are otherwise very solid and comprehensive.
Also applies to: 372-383
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- `deepmd/pt/optimizer/__init__.py`
- `deepmd/pt/optimizer/adamuon.py`
- `deepmd/pt/train/training.py`
- `deepmd/utils/argcheck.py`
- `source/tests/pt/test_adamuon.py`
🧰 Additional context used
🧬 Code graph analysis (2)
deepmd/pt/train/training.py (2)
- deepmd/pt/optimizer/adamuon.py (2)
  - `AdaMuonOptimizer` (151-409)
  - `step` (234-409)
- deepmd/pd/train/training.py (2)
  - `step` (734-978)
  - `warm_up_linear` (586-590)

deepmd/pt/optimizer/__init__.py (1)
- deepmd/pt/optimizer/adamuon.py (1)
  - `AdaMuonOptimizer` (151-409)
🪛 Ruff (0.14.10)
source/tests/pt/test_adamuon.py
12-12: Unused lambda argument: args
(ARG005)
12-12: Unused lambda argument: kwargs
(ARG005)
14-14: Unused lambda argument: args
(ARG005)
14-14: Unused lambda argument: kwargs
(ARG005)
100-100: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
372-372: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
deepmd/pt/optimizer/adamuon.py
77-77: Avoid specifying long messages outside the exception class
(TRY003)
79-79: Avoid specifying long messages outside the exception class
(TRY003)
375-375: Loop control variable update not used within loop body
(B007)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (30)
- GitHub Check: CodeQL analysis (python)
- GitHub Check: Agent
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Test Python (2, 3.13)
- GitHub Check: Test Python (6, 3.13)
- GitHub Check: Test Python (5, 3.10)
- GitHub Check: Test Python (3, 3.10)
- GitHub Check: Test Python (5, 3.13)
- GitHub Check: Test Python (6, 3.10)
- GitHub Check: Test Python (4, 3.10)
- GitHub Check: Test Python (4, 3.13)
- GitHub Check: Test Python (1, 3.13)
- GitHub Check: Test Python (3, 3.13)
- GitHub Check: Test Python (1, 3.10)
- GitHub Check: Test Python (2, 3.10)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Analyze (python)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Test C++ (true, true, true, false)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Test C++ (false, true, true, false)
- GitHub Check: Test C++ (false, false, false, true)
- GitHub Check: Test C++ (true, false, false, true)
🔇 Additional comments (1)
deepmd/pt/optimizer/__init__.py (1)
2-4: AdaMuon export wiring looks correct.

Importing `AdaMuonOptimizer` here and adding it to `__all__` cleanly exposes the new optimizer via `deepmd.pt.optimizer`. No issues from this change.

Also applies to: 12-12
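The wiring in question amounts to something like the following sketch (not a verbatim copy of the file):

```python
# deepmd/pt/optimizer/__init__.py (sketch)
from .adamuon import (
    AdaMuonOptimizer,
)

__all__ = [
    "AdaMuonOptimizer",
    # ... existing optimizer exports ...
]
```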
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In @deepmd/pt/optimizer/adamuon.py:
- Around line 40-109: The docstring for zeropower_via_newtonschulz5 should be
updated to explicitly state that the hard-coded coefficients (3.4445, -4.7750,
2.0315) are the "cursed quintic" from Muon/NanoGPT and that this variant
intentionally trades formal convergence guarantees for faster wall-clock
orthogonalization (they sum to ~0.701, so x=1 is not a fixed point and the
iteration can oscillate); revise the description to call this a
non-convergent/quasi-Newton-Schulz heuristic, warn users about possible
oscillation and lack of convergence guarantees, and mention that it operates in
bfloat16 and returns bfloat16 so numerical behavior may differ from a standard
convergent Newton-Schulz implementation (keep the rest of the function logic
unchanged).
🧹 Nitpick comments (3)
deepmd/pt/train/training.py (1)
707-725: Consider exposing `lr_adjust` and `lr_adjust_coeff` in `get_opt_param`.

The initialization correctly retrieves `lr_adjust` and `lr_adjust_coeff` from `opt_param`, but these parameters are not added to `get_opt_param` (lines 158-173). They will always use the defaults (10.0 and 0.2). If users should be able to configure these via the training config, consider adding them:

```diff
 # Muon/AdaMuon parameters
 "momentum": params.get("momentum", 0.95),
 "adam_beta1": params.get("adam_beta1", 0.9),
 "adam_beta2": params.get("adam_beta2", 0.95),
+"lr_adjust": params.get("lr_adjust", 10.0),
+"lr_adjust_coeff": params.get("lr_adjust_coeff", 0.2),
```

Otherwise, if this is intentional to keep the interface simpler, no change is needed.
deepmd/pt/optimizer/adamuon.py (2)
373-407: Unused loop variable `update` in tuple unpacking.

At line 375, the loop variable `update` is unpacked but never used within the loop body. This is because the `update` tensor was already processed when building the `stacked` tensor at line 370. Consider renaming it to `_` to indicate it's intentionally unused:

♻️ Suggested fix

```diff
-for i, (p, update, orig_shape) in enumerate(bucket):
+for i, (p, _update, orig_shape) in enumerate(bucket):
```
313-325: Add a comment documenting the reliance on `torch._foreach_*` private APIs.

While `torch._foreach_mul_`, `torch._foreach_lerp_`, and `torch._foreach_mul` are private APIs (explicitly marked as subject to change without warning), PyTorch's official Adam optimizer uses them for performance-critical batch operations. Consider adding an inline comment explaining this deliberate choice, noting that these private APIs are used for optimization purposes and may require updates if PyTorch changes their API in future versions.
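A sketch of the kind of inline note being suggested, attached to a toy batched update (dummy tensor lists; only the comment placement is the point):

```python
import torch

params = [torch.randn(3, 3) for _ in range(2)]
grads = [torch.randn(3, 3) for _ in range(2)]

# NOTE: torch._foreach_* are private PyTorch APIs, used here (as in PyTorch's own
# Adam implementation) for performance-critical batched updates; they may change
# without notice in future PyTorch releases.
torch._foreach_mul_(params, 0.999)        # e.g. decoupled weight decay, batched
torch._foreach_lerp_(params, grads, 0.1)  # e.g. batched in-place EMA-style update
```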
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- `deepmd/pt/optimizer/adamuon.py`
- `deepmd/pt/train/training.py`
- `deepmd/utils/argcheck.py`
🚧 Files skipped from review as they are similar to previous changes (1)
- deepmd/utils/argcheck.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2024-10-15T22:22:24.889Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4219
File: deepmd/utils/learning_rate.py:48-53
Timestamp: 2024-10-15T22:22:24.889Z
Learning: Methods in `deepmd/utils/learning_rate.py` that return NumPy scalar types should have return type annotations using the corresponding NumPy types, such as `np.float64`.
Applied to files:
deepmd/pt/optimizer/adamuon.py
🧬 Code graph analysis (2)
deepmd/pt/optimizer/adamuon.py (2)
- deepmd/pt/train/training.py (1)
  - `step` (772-1144)
- source/tests/pt/test_adamuon.py (1)
  - `closure` (328-334)

deepmd/pt/train/training.py (2)
- deepmd/pt/optimizer/adamuon.py (2)
  - `AdaMuonOptimizer` (151-409)
  - `step` (234-409)
- deepmd/pd/train/training.py (2)
  - `step` (734-978)
🪛 Ruff (0.14.10)
deepmd/pt/optimizer/adamuon.py
77-77: Avoid specifying long messages outside the exception class
(TRY003)
79-79: Avoid specifying long messages outside the exception class
(TRY003)
375-375: Loop control variable update not used within loop body
(B007)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (35)
- GitHub Check: Test Python (9, 3.10)
- GitHub Check: Test Python (8, 3.10)
- GitHub Check: Test Python (9, 3.13)
- GitHub Check: Test Python (11, 3.13)
- GitHub Check: Test Python (10, 3.13)
- GitHub Check: Test Python (2, 3.10)
- GitHub Check: Test Python (6, 3.13)
- GitHub Check: Test Python (12, 3.10)
- GitHub Check: Test Python (8, 3.13)
- GitHub Check: Test Python (11, 3.10)
- GitHub Check: Test Python (5, 3.10)
- GitHub Check: Test Python (4, 3.13)
- GitHub Check: Test Python (4, 3.10)
- GitHub Check: Test Python (6, 3.10)
- GitHub Check: Test Python (2, 3.13)
- GitHub Check: Test Python (3, 3.10)
- GitHub Check: Test Python (1, 3.13)
- GitHub Check: Test Python (3, 3.13)
- GitHub Check: Test Python (1, 3.10)
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Test C++ (false, false, false, true)
- GitHub Check: Test C++ (true, false, false, true)
- GitHub Check: Test C++ (false, true, true, false)
- GitHub Check: Test C++ (true, true, true, false)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Analyze (python)
🔇 Additional comments (8)
deepmd/pt/train/training.py (3)
44-48: LGTM! The import of `AdaMuonOptimizer` is correctly placed with other optimizer imports and aligns with the module structure.

166-172: LGTM! The new optimizer parameters are properly documented with comments and their default values align with the `AdaMuonOptimizer` constructor.

796-814: LGTM! Adding `"AdaMuon"` to the optimizer type check correctly integrates it with the existing training loop, including warmup handling, gradient clipping, and scheduler stepping.

deepmd/pt/optimizer/adamuon.py (5)

1-38: LGTM! Well-structured module with clear documentation explaining the optimizer's purpose and key improvements. The reference to the source repository is helpful for traceability.

112-148: LGTM! The momentum update and Nesterov handling are correctly implemented. The reshape logic properly handles higher-dimensional tensors for matrix-based processing.

205-231: LGTM! Clean constructor following the standard PyTorch `Optimizer` pattern with well-documented defaults.

391-399: Numerical stability handling looks solid. The implementation properly guards against division by zero with `eps` additions at critical points (the `v_buffer` sqrt at line 394 and the norm division at line 399). The use of FP32 for accumulator buffers (`v_buffer`, `exp_avg`, `exp_avg_sq`) helps maintain precision.

151-203: Well-documented class with comprehensive docstring. The class docstring clearly explains the dual update strategy (AdaMuon for >=2D, Adam for 1D params), key features, and parameter semantics including the `lr_adjust` mode behavior. This aids usability significantly.
deepmd/pt/optimizer/adamuon.py

```python
def zeropower_via_newtonschulz5(
    G: torch.Tensor,
    steps: int = 5,
    eps: float = 1e-8,
) -> torch.Tensor:
    """
    Compute the zeroth power (orthogonalization) of a matrix via Newton-Schulz iteration.

    Uses quintic Newton-Schulz iteration to compute the orthogonal component of the
    input matrix. This is equivalent to computing U from the SVD decomposition G = USV^T.
    This implementation always performs Newton-Schulz in bfloat16 and returns a
    bfloat16 tensor.

    Parameters
    ----------
    G : torch.Tensor
        Input matrix to orthogonalize with shape (..., M, N).
    steps : int
        Number of Newton-Schulz iterations with default 5.
    eps : float
        Numerical stability epsilon for norm clamping with default 1e-8.

    Returns
    -------
    torch.Tensor
        Orthogonalized matrix in bfloat16 with same shape as input.

    Raises
    ------
    ValueError
        If G has fewer than 2 dimensions.
    ValueError
        If steps >= 100 (guard for efficiency).
    """
    # === Step 1. Validate ===
    if G.ndim < 2:
        raise ValueError("Input must have at least 2 dimensions (..., M, N).")
    if steps >= 100:
        raise ValueError("Number of steps must be less than 100 for efficiency.")

    a, b, c = (3.4445, -4.7750, 2.0315)

    # === Step 2. Cast to bf16 ===
    X = G.to(dtype=torch.bfloat16)

    # === Step 3. Transpose tall matrices ===
    if X.size(-2) > X.size(-1):
        X = X.mT

    # === Step 4. Normalize Frobenius norm to at most 1 ===
    X = X / X.norm(dim=(-2, -1), keepdim=True).clamp(min=eps)

    # === Step 5. Newton-Schulz iterations with fused GEMM ===
    for _ in range(steps):
        A = X @ X.mT
        # gram_update = b*A + c*(A@A) via addmm/baddbmm
        # X = a*X + gram_update@X via addmm/baddbmm
        if X.ndim == 2:
            gram_update = torch.addmm(A, A, A, beta=b, alpha=c)
            X = torch.addmm(X, gram_update, X, beta=a, alpha=1.0)
        else:
            gram_update = torch.baddbmm(A, A, A, beta=b, alpha=c)
            X = torch.baddbmm(X, gram_update, X, beta=a, alpha=1.0)

    # === Step 6. Transpose back if needed ===
    if G.size(-2) > G.size(-1):
        X = X.mT

    return X
```
🧩 Analysis chain
🌐 Web query:
Newton-Schulz quintic iteration coefficients 3.4445 -4.7750 2.0315
💡 Result:
Those numbers are the "cursed quintic" Newton–Schulz polynomial used in the Muon/NanoGPT speedrun:
f(x) = 3.4445 x − 4.7750 x^3 + 2.0315 x^5.
It was chosen to rapidly inflate small singular values (faster wall‑clock orthogonalization) but it is non‑convergent — its coefficients sum to 0.701 ≠ 1, so x=1 is not a fixed point and the iteration can oscillate. [1][2]
Sources:
- Modula docs (discussion of the cursed quintic and coefficients). [1]
- ArXiv paper on Chebyshev‑optimized Newton‑Schulz (context on higher‑degree Newton–Schulz polynomials). [2]
Update docstring to disclose non-convergent trade-off and clarify these are the "cursed quintic" coefficients.
The coefficients (3.4445, -4.7750, 2.0315) are the "cursed quintic" from Muon/NanoGPT, designed for speed. However, these do not form a standard convergent Newton-Schulz iteration—they sum to 0.701 ≠ 1, so x=1 is not a fixed point, and the iteration can oscillate.
The docstring should clarify that this trades convergence guarantees for faster wall-clock orthogonalization, not that it is a standard Newton-Schulz method. Users should know they are using a non-convergent variant with intentional trade-offs.
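A quick numerical check of this claim (plain Python, no dependencies):

```python
# Quintic used in the PR: f(x) = a*x + b*x**3 + c*x**5
a, b, c = 3.4445, -4.7750, 2.0315
print(a + b + c)  # ~0.701 -> f(1) = 0.701, so x = 1 is not a fixed point


def f(x):
    return a * x + b * x**3 + c * x**5


x = 0.1  # a small singular value
for i in range(8):
    x = f(x)
    print(i, round(x, 4))
# The iterate is inflated toward 1 within a couple of steps but then keeps
# oscillating (roughly between ~0.7 and ~1.1) instead of converging to 1.
```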
🧰 Tools
🪛 Ruff (0.14.10)
77-77: Avoid specifying long messages outside the exception class
(TRY003)
79-79: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In @deepmd/pt/optimizer/adamuon.py around lines 40 - 109, The docstring for
zeropower_via_newtonschulz5 should be updated to explicitly state that the
hard-coded coefficients (3.4445, -4.7750, 2.0315) are the "cursed quintic" from
Muon/NanoGPT and that this variant intentionally trades formal convergence
guarantees for faster wall-clock orthogonalization (they sum to ~0.701, so x=1
is not a fixed point and the iteration can oscillate); revise the description to
call this a non-convergent/quasi-Newton-Schulz heuristic, warn users about
possible oscillation and lack of convergence guarantees, and mention that it
operates in bfloat16 and returns bfloat16 so numerical behavior may differ from
a standard convergent Newton-Schulz implementation (keep the rest of the
function logic unchanged).
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In @deepmd/pt/optimizer/adamuon.py:
- Around line 44-113: The default eps=1e-8 in zeropower_via_newtonschulz5 is too
small for pure bfloat16 work—raise the default to at least 1e-3 (preferably
1e-1) or perform the normalization in float32 before casting to bfloat16:
compute norm = G.norm(dim=(-2,-1), keepdim=True).clamp(min=eps) in float32 and
divide G by that norm, then cast the result to torch.bfloat16 for the
Newton‑Schulz iterations; ensure the function signature default eps is updated
(or validated) to >=1e-3 if you choose the simpler default change.
🧹 Nitpick comments (1)
deepmd/pt/optimizer/adamuon.py (1)
379-379: Minor: Loop variable `update` is unused.

The loop unpacks `(p, update, orig_shape)` but only uses `p` and `orig_shape`. The `update` variable is unused because the orthogonalized sign matrix is retrieved via `orth_stacked[i]` instead.

While this is intentional (`update` was already used for the sign computation in line 374), consider renaming it to `_` for clarity:

```diff
-for i, (p, update, orig_shape) in enumerate(bucket):
+for i, (p, _, orig_shape) in enumerate(bucket):
```
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
deepmd/pt/optimizer/adamuon.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2024-10-15T22:22:24.889Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4219
File: deepmd/utils/learning_rate.py:48-53
Timestamp: 2024-10-15T22:22:24.889Z
Learning: Methods in `deepmd/utils/learning_rate.py` that return NumPy scalar types should have return type annotations using the corresponding NumPy types, such as `np.float64`.
Applied to files:
deepmd/pt/optimizer/adamuon.py
🪛 Ruff (0.14.10)
deepmd/pt/optimizer/adamuon.py
81-81: Avoid specifying long messages outside the exception class
(TRY003)
83-83: Avoid specifying long messages outside the exception class
(TRY003)
379-379: Loop control variable update not used within loop body
(B007)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (34)
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Analyze (python)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Test Python (3, 3.13)
- GitHub Check: Test Python (6, 3.10)
- GitHub Check: Test Python (5, 3.13)
- GitHub Check: Test Python (6, 3.13)
- GitHub Check: Test Python (8, 3.10)
- GitHub Check: Test Python (2, 3.10)
- GitHub Check: Test Python (5, 3.10)
- GitHub Check: Test Python (4, 3.13)
- GitHub Check: Test Python (12, 3.13)
- GitHub Check: Test Python (9, 3.10)
- GitHub Check: Test Python (11, 3.10)
- GitHub Check: Test Python (11, 3.13)
- GitHub Check: Test Python (3, 3.10)
- GitHub Check: Test Python (8, 3.13)
- GitHub Check: Test Python (10, 3.13)
- GitHub Check: Test Python (12, 3.10)
- GitHub Check: Test Python (2, 3.13)
- GitHub Check: Test Python (10, 3.10)
- GitHub Check: Test Python (1, 3.13)
- GitHub Check: Test Python (7, 3.10)
- GitHub Check: Test Python (4, 3.10)
- GitHub Check: Test Python (9, 3.13)
- GitHub Check: Test Python (1, 3.10)
- GitHub Check: Test Python (7, 3.13)
- GitHub Check: Test C++ (false, true, true, false)
- GitHub Check: Test C++ (true, false, false, true)
- GitHub Check: Test C++ (true, true, true, false)
- GitHub Check: Test C++ (false, false, false, true)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
🔇 Additional comments (8)
deepmd/pt/optimizer/adamuon.py (8)
116-152: LGTM! The momentum preparation logic is clean and correct. The in-place momentum update is properly documented, and the Nesterov logic follows standard conventions.

209-235: LGTM! The initialization follows PyTorch optimizer conventions correctly, with sensible default values that match the documentation.

255-316: LGTM! The parameter collection and state initialization logic is well-structured. Key strengths:
- Clean separation of AdaMuon (>=2D) and Adam (1D) paths.
- Adam state tensors maintained in FP32 for numerical stability.
- Proper bias correction setup with beta power tracking.
318-341: LGTM! The weight decay and Adam update implementation is solid:

- Decoupled weight decay (AdamW-style) is correctly applied before updates.
- Efficient batched operations via `_foreach` primitives.
- Full FP32 computation path for Adam with final dtype cast ensures numerical stability.

343-370: LGTM! The bucketing strategy and scaling computation are well-designed:
- Efficient batching by shape and device for Newton-Schulz.
- Dual lr_adjust behavior provides flexibility for different scaling strategies.
- Pre-computed bucket-level constants avoid redundant computation.
371-376: LGTM! The sign-stabilized orthogonalization is correctly implemented:

- The `contiguous()` call ensures proper memory layout before batched operations.
- Stacking sign matrices enables efficient batched Newton-Schulz processing.

409-411: LGTM! The final parameter update correctly:
- Reshapes the orthogonalized vector back to original shape.
- Casts to parameter dtype.
- Applies with negative learning rate scaling.
387-398: The review comment's analysis is incomplete and misses a crucial mitigating factor.

The `v_buffer` is initialized to zeros as stated, and the element-wise normalization on the first step is mathematically correct: `orth_vec / sqrt((1 - momentum) * orth_vec²) ≈ sign(orth_vec) / sqrt(1 - momentum)`. However, the review fails to account for the RMS-aligned global scaling applied immediately after (lines 400-403), which significantly mitigates the first-step effect:

- After element-wise normalization (line 398): scaling factor ≈ 4.47x per element.
- Immediately after, global RMS-aligned scaling (line 403): `orth_vec / norm(orth_vec) * 0.2 * sqrt(m*n)`.
- The global scaling re-normalizes the vector and applies a 0.2 coefficient, substantially dampening the first-step update.
Additionally, zero initialization of the second-moment buffer is the standard practice in reference AdaMuon implementations and matches the paper's algorithm. This is intentional and not an instability concern—the RMS-aligned global scaling (0.2 * sqrt(m*n) / norm) is the paper's mechanism for controlling first-step behavior.
No warm-up schedule or non-zero initialization is needed.
Likely an incorrect or invalid review comment.
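A small numeric check of the two scaling stages discussed in this thread (standalone, illustrative sizes):

```python
import math

import torch

momentum = 0.95
m, n = 64, 32
orth_vec = torch.randn(m * n)  # flattened orthogonalized direction

# First step: v_buffer starts at zero, so element-wise normalization gives
# roughly sign(orth_vec) / sqrt(1 - momentum) ~= 4.47 * sign(orth_vec).
v_buffer = (1 - momentum) * orth_vec.square()
per_elem = orth_vec / v_buffer.sqrt()
print(per_elem.abs().mean())  # ~4.47

# The RMS-aligned global scaling applied right after re-normalizes the vector,
# so the 4.47x inflation does not survive into the actual update magnitude.
scaled = per_elem / per_elem.norm() * 0.2 * math.sqrt(m * n)
print(scaled.norm())  # 0.2 * sqrt(m * n) ≈ 9.05, independent of the per-element factor
```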
(Same `zeropower_via_newtonschulz5` diff as quoted above.)
🧩 Analysis chain
🌐 Web query:
bfloat16 numerical precision mantissa bits decimal digits representation
💡 Result:
- Bit layout: 1 sign, 8 exponent, 7 fraction (mantissa) bits. [1][2]
- Effective significand precision (normalized numbers) = 7 + 1 (implicit leading 1) = 8 bits. [1][2]
- Decimal digits of precision ≈ log10(2^8) ≈ 2.41 → about 2–3 significant decimal digits. [2][3]
References
[1] oneAPI / oneDNN bfloat16 documentation.
[2] Wikipedia — Bfloat16 floating-point format.
[3] Explanation of bfloat16 bit layout and precision (bitstring / format docs).
🌐 Web query:
Newton-Schulz iteration epsilon eps parameter convergence reduced precision float16 bfloat16
💡 Result:
Short guidance and practical choices.
-
Machine-unit roundoff (u): float16 (binary16) has 10-bit fraction → u ≈ 2^-10 ≈ 9.77×10^-4. bfloat16 has 7-bit fraction → u ≈ 2^-7 = 7.8125×10^-3. Use these as the floor for any stopping tolerance. [1][2]
-
Stopping criterion (recommended): monitor the relative residual R_k = ||I - A X_k|| / ||A|| (or ||I - X_k A||). Stop when R_k ≤ C·u where C∈[1,100] depending on safety margin and whether you use FP32 accumulation. For pure float16/bfloat16 math use C≈10–100; for mixed-precision (FP16/BF16 compute, FP32 accumulators or casts) you can use C≈1–10. Typical practical tolerances:
- float16: eps ≈ 1e-3 … 1e-1 (safer: 1e-2).
- bfloat16: eps ≈ 8e-3 … 8e-2 (safer: 1e-1).
These bounds follow from the unit roundoff values and observed stability of Newton–Schulz under reduced precision. [1][3][4]
-
Iteration count: Newton–Schulz converges quadratically after normalization; in practice 3–6 iterations (often 3–5) suffice when normalized and using quintic/cubic variants used in ML orthogonalization. Stop earlier if R_k ≤ eps to avoid chasing rounding noise. [3][4]
-
Practical stability tips
- Always pre-normalize X0 (scale A so spectral norm ≤ 1) — this is essential for contractive Newton–Schulz behavior. Use mixed precision for the normalization constant if available. [3][5]
- Prefer mixed precision: do mat-muls in FP16/BF16 but accumulate or compute residuals in FP32 when possible (reduces effective error and lets you set smaller eps). [2][6]
- If only pure 16-bit is available, increase eps (looser tolerance) and/or keep a few critical buffers (scaling, residual) in FP32. [1][2][3]
- If you need a guaranteed final FP32-level accuracy, perform a final correction or one Newton step in FP32 with the 16-bit result as initial guess. [4][5]
References
- FP formats and bit-allocation (float16, bfloat16): Intel / ML systems summaries. [1][2]
- Newton–Schulz convergence and low-precision usage (MDPI study, Muon/nanochat discussions on ML orthogonalization and practical iterations). [3][4]
- Mixed-precision advice and TensorRT/BFloat16 practice notes. [5][6]
If you want, I can: (a) compute exact unit roundoffs and suggested numeric eps for your hardware, (b) suggest pseudocode for a robust stopping check using mixed precision, or (c) tune eps for a specific matrix condition number — tell me which.
🏁 Script executed:
```shell
cat -n deepmd/pt/optimizer/adamuon.py | sed -n '44,113p'
```

Repository: deepmodeling/deepmd-kit
Length of output: 2782
Epsilon value is inappropriate for bfloat16 computation—use at least 1e-3, preferably 1e-1.

The function performs all Newton-Schulz iterations in pure bfloat16 (line 88) but defaults to `eps=1e-8` (line 47). However, bfloat16 has only ~7-8 bits of mantissa precision and a machine unit roundoff of approximately 7.8×10⁻³. The default `eps=1e-8` is nearly six orders of magnitude smaller than this roundoff, making it ineffective for norm clamping at line 95.

For pure bfloat16 Newton-Schulz iterations, recommended epsilon values are 1e-3 to 1e-1 (with 1e-1 being safer). Either:

- Increase the default `eps` parameter to at least `1e-3`, or
- Use mixed-precision normalization (perform the norm and clamping in float32 before casting to bfloat16).

This affects convergence stability and the effectiveness of the normalization guard against division by zero.
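A quick check of the precision numbers cited above (PyTorch only):

```python
import torch

print(torch.finfo(torch.bfloat16).eps)  # 0.0078125 == 2**-7, bf16 machine epsilon

# 1e-8 is representable in bf16 (same exponent range as fp32) ...
x = torch.tensor(1e-8, dtype=torch.bfloat16)
print(x)  # ~1.0012e-08

# ... but as a clamp floor it sits several orders of magnitude below bf16
# rounding error, so once the norm itself is computed in bf16 the clamp adds
# no meaningful protection against near-zero norms.
print(torch.finfo(torch.bfloat16).eps / 1e-8)  # ~780_000
```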
🧰 Tools
🪛 Ruff (0.14.10)
81-81: Avoid specifying long messages outside the exception class
(TRY003)
83-83: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In @deepmd/pt/optimizer/adamuon.py around lines 44 - 113, The default eps=1e-8
in zeropower_via_newtonschulz5 is too small for pure bfloat16 work—raise the
default to at least 1e-3 (preferably 1e-1) or perform the normalization in
float32 before casting to bfloat16: compute norm = G.norm(dim=(-2,-1),
keepdim=True).clamp(min=eps) in float32 and divide G by that norm, then cast the
result to torch.bfloat16 for the Newton‑Schulz iterations; ensure the function
signature default eps is updated (or validated) to >=1e-3 if you choose the
simpler default change.
Codecov Report

❌ Patch coverage is

Additional details and impacted files

```
@@            Coverage Diff             @@
##           master    #5130      +/-   ##
==========================================
- Coverage   82.15%   81.96%   -0.20%
==========================================
  Files         709      713       +4
  Lines       72468    73021     +553
  Branches     3616     3616
==========================================
+ Hits        59535    59849     +314
- Misses      11769    12009     +240
+ Partials     1164     1163       -1
```

☔ View full report in Codecov by Sentry.
Summary by CodeRabbit
- New Features
- Configuration
- Tests