
Conversation


@OutisLi OutisLi commented Jan 6, 2026

Summary by CodeRabbit

  • New Features

    • Added AdaMuon optimizer for PyTorch training with mixed 2D/1D parameter handling, batched orthogonalized updates, momentum and per-parameter state.
  • Configuration

    • New optimizer options: momentum, weight decay, Adam betas, lr adjustment modes and tuning coefficients; AdaMuon selectable as an optimizer.
  • Tests

    • Added comprehensive tests for orthogonalization, optimizer updates, bucketing, lr-adjust modes, weight decay, closure behavior, and state save/load.


(cherry picked from commit 81d66ff)
(cherry picked from commit ea5ac54)
Copilot AI review requested due to automatic review settings January 6, 2026 15:42
@github-actions github-actions bot added the Python label Jan 6, 2026
@dosubot dosubot bot added the new feature label Jan 6, 2026

Copilot AI left a comment


Pull request overview

This PR adds the AdaMuon optimizer to the DeePMD-kit PyTorch backend, which combines Newton-Schulz orthogonalization with adaptive second-moment normalization for improved training stability. The optimizer applies different update rules based on parameter dimensionality: AdaMuon for 2D+ weight matrices and standard Adam for 1D parameters (biases, layer norms).

Key changes:

  • Implementation of AdaMuonOptimizer with sign-stabilized orthogonal direction updates and per-element normalization
  • Comprehensive test suite covering basic functionality, state management, bucketing, learning rate adjustment, and weight decay
  • Integration with the training pipeline including configuration arguments and optimizer initialization
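For orientation, a minimal usage sketch is shown below; the import path matches the files changed in this PR, while the keyword names and values are assumptions based on the defaults quoted later in this thread (momentum 0.95, weight_decay 0.001), not the exact constructor signature.

import torch

from deepmd.pt.optimizer import AdaMuonOptimizer  # export added in this PR

# Tiny model with both >=2D weights (AdaMuon path) and 1D biases (Adam path).
model = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 1))
optimizer = AdaMuonOptimizer(
    model.parameters(),
    lr=1e-3,             # base learning rate (name/default assumed)
    momentum=0.95,       # Muon momentum, default quoted in this thread
    weight_decay=0.001,  # decoupled weight decay, default quoted in this thread
)

loss = model(torch.randn(4, 8)).pow(2).mean()
loss.backward()
optimizer.step()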

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 9 comments.

  • deepmd/pt/optimizer/adamuon.py: New AdaMuonOptimizer implementation with Newton-Schulz orthogonalization and adaptive normalization
  • source/tests/pt/test_adamuon.py: Comprehensive test suite covering optimizer functionality, state management, and various configurations
  • deepmd/utils/argcheck.py: Added configuration arguments for AdaMuon optimizer parameters (momentum, betas, weight_decay, lr_adjust)
  • deepmd/pt/train/training.py: Integrated AdaMuon optimizer into training pipeline with parameter extraction and scheduler setup
  • deepmd/pt/optimizer/__init__.py: Added AdaMuonOptimizer to module exports



coderabbitai bot commented Jan 6, 2026

📝 Walkthrough

Walkthrough

Adds a new AdaMuonOptimizer and helpers (Newton–Schulz orthogonalization + momentum prep), integrates it into training arg parsing and Trainer optimizer flow, and introduces tests validating behavior, state handling, bucketing, lr_adjust modes, and weight decay.

Changes

  • Optimizer implementation (deepmd/pt/optimizer/__init__.py, deepmd/pt/optimizer/adamuon.py):
    New AdaMuonOptimizer plus helpers zeropower_via_newtonschulz5 and _prepare_muon_momentum. Splits params by ndim (≥2D → AdaMuon path, 1D → Adam path), implements batched Newton–Schulz orthogonalization with bucketing, per-parameter momentum and v_buffer state, per-element EMA normalization, RMS-aligned/global scaling, and exposes AdaMuonOptimizer via the package init.
  • Training integration (deepmd/pt/train/training.py):
    Imports and wires AdaMuonOptimizer into optimizer selection; extends get_opt_param to include weight_decay/momentum/adam_betas; restores optimizer state on restart; adds a LambdaLR scheduler and warmup handling for AdaMuon steps.
  • Argument validation (deepmd/utils/argcheck.py):
    Adds an AdaMuon variant to the training args with fields momentum, adam_beta1, adam_beta2, weight_decay, lr_adjust, and lr_adjust_coeff (with defaults and PyTorch-only docs).
  • Tests (source/tests/pt/test_adamuon.py):
    New tests validating Newton–Schulz orthogonalization (shape/dtype/orthogonality), the optimizer step for mixed 2D/1D params, state creation (momentum/v_buffer/exp_avg/exp_avg_sq), bucketing behavior, lr_adjust modes, weight decay semantics, closure handling, and state_dict save/load fidelity.
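To make the new configuration surface concrete, a hypothetical training-section snippet is sketched below; the field names follow the argcheck summary above and the values are the defaults quoted in this thread, but the exact nesting and remaining keys of a real input.json may differ.

training_section = {
    "opt_type": "AdaMuon",
    "momentum": 0.95,
    "adam_beta1": 0.9,
    "adam_beta2": 0.95,
    "weight_decay": 0.001,
    "lr_adjust": 10.0,
    "lr_adjust_coeff": 0.2,
    # ... other training options unchanged
}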

Sequence Diagram

sequenceDiagram
    participant Trainer
    participant AdaMuonOptimizer
    participant Segregator as Param<br/>Segregation
    participant AdamPath as Adam<br/>(1D)
    participant MuonPath as AdaMuon<br/>(≥2D)
    participant Buckets
    participant NS as Newton-Schulz
    participant Scale as Norm & Scale
    participant Params as Parameter<br/>Update

    Trainer->>AdaMuonOptimizer: step()
    AdaMuonOptimizer->>Segregator: separate params by ndim
    Segregator-->>AdamPath: 1D params
    Segregator-->>MuonPath: ≥2D params

    AdamPath->>AdamPath: update exp_avg/exp_avg_sq (FP32) and compute step
    AdamPath->>Params: apply scaled Adam updates

    MuonPath->>MuonPath: apply weight decay, update momentum buffers
    MuonPath->>Buckets: reshape to (M,N) and group by shape/device
    Buckets->>NS: stack sign matrices, run Newton–Schulz
    NS-->>Scale: return orthogonal directions (bf16)
    Scale->>Scale: update EMA v_buffer, compute per-element norm, apply RMS scaling and lr_adjust
    Scale->>Params: apply orthogonalized, scaled updates (-lr × update)
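As a concrete anchor for the "Param Segregation" step above, the ndim-based split amounts to something like the following standalone snippet (illustrative only, not the optimizer's code):

import torch

model = torch.nn.Linear(4, 3)
model(torch.randn(2, 4)).sum().backward()

muon_params = [p for p in model.parameters() if p.ndim >= 2]  # weight -> AdaMuon path
adam_params = [p for p in model.parameters() if p.ndim == 1]  # bias   -> Adam path
print([tuple(p.shape) for p in muon_params])  # [(3, 4)]
print([tuple(p.shape) for p in adam_params])  # [(3,)]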

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Possibly related PRs

Suggested reviewers

  • wanghan-iapcm
🚥 Pre-merge checks: 2 passed, 1 failed
❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 44.44%, below the required threshold of 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed: Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: The title 'feat(pt): add AdaMuon optimizer' directly and clearly describes the main change: introducing a new AdaMuon optimizer to the PyTorch backend, as evidenced by the new adamuon.py file, integration into the optimizer module, trainer configuration, and comprehensive test suite.




@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
deepmd/pt/train/training.py (1)

158-175: Pass collected adam_eps and nesterov parameters to AdaMuonOptimizer, or remove them from get_opt_param() if fixed

get_opt_param() collects adam_eps (default 1e-7) and nesterov (default True) into opt_param, but the AdaMuon branch (lines 709–720) does not pass these to the AdaMuonOptimizer constructor. They are functional parameters inside AdaMuonOptimizer (used in line 334 for the 1D parameter denominator and line 141 for momentum update logic), but they always use the optimizer's constructor defaults instead of the config values.

Either thread them through explicitly:

adam_eps=float(self.opt_param.get("adam_eps", 1e-8)),
nesterov=bool(self.opt_param.get("nesterov", True)),

or remove them from get_opt_param() to avoid a misleading API surface.

The shared weight_decay default of 0.001 is consistent and correct.

🧹 Nitpick comments (7)
deepmd/utils/argcheck.py (1)

3373-3427: Align AdaMuon training schema with actual optimizer parameters

The new AdaMuon variant under opt_type exposes muon_momentum, adam_beta1/2, weight_decay, lr_adjust, and lr_adjust_coeff, which matches most of what Trainer passes into AdaMuonOptimizer. However:

  • get_opt_param() in deepmd/pt/train/training.py also populates adam_eps and nesterov in opt_param, but those values are never forwarded into AdaMuonOptimizer’s constructor.
  • There is currently no public schema here for adam_eps or nesterov, so they are effectively non‑configurable and the opt_param entries are dead.

Consider either:

  • Exposing adam_eps and nesterov here and passing them through when constructing AdaMuonOptimizer, or
  • Dropping them from opt_param if you intend to keep them fixed at the optimizer defaults.

This keeps the configuration surface and runtime behavior in sync.

deepmd/pt/train/training.py (1)

798-804: AdaMuon integrated into main SGD path; multi-task still not supported

Including "AdaMuon" in the main if self.opt_type in ["Adam", "AdamW", "AdaMuon"]: block correctly reuses the Adam-style step and LR scheduling logic for the new optimizer.

Note that in multi-task mode self.opt_type becomes a dict, and this branch (as indicated by the nearby TODO) still doesn’t support per-task optimizers. If AdaMuon is not meant to be used with multi-task training yet, that’s fine; otherwise you may want an explicit check/error when multi_task=True and opt_type == "AdaMuon" is requested.

deepmd/pt/optimizer/adamuon.py (3)

40-110: Newton–Schulz helper looks sound; consider future configurability

The zeropower_via_newtonschulz5 implementation matches the described quintic Newton–Schulz scheme (bf16 compute, Frobenius normalization, tall-matrix transpose, fused GEMMs), and the guards on ndim and steps are reasonable.

If you later need more flexibility, two low-impact extensions would be:

  • Allowing an optional dtype/use_bfloat16 flag so callers can keep everything in FP32 on hardware where bf16 GEMMs are slow or unavailable.
  • Relaxing the steps >= 100 hard check to a warning or configurable cap.

Neither is blocking; the current implementation is fine as-is.


112-148: Momentum preparation helper: clarify Nesterov behavior

_prepare_muon_momentum cleanly encapsulates the momentum update and reshape logic, but the Nesterov path:

momentum_buffer.lerp_(grad, 1 - beta)
update = grad.lerp(momentum_buffer, beta) if nesterov else momentum_buffer

does not match the standard v = β v + g; u = g + β v Nesterov formulation. If this is intentionally following the AdaMuon reference implementation, a short comment would help avoid future “fixes” that change behavior. If the goal was standard Nesterov, consider switching to the canonical update.

Also, original_shape/reshape handling for >2D params is clear and works well.
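For reference, a standalone comparison of the two formulations (assuming beta = 0.95; this is not the optimizer code, just the algebra behind the comment above):

import torch

beta = 0.95
grad = torch.randn(4)
buf_lerp = torch.zeros(4)
buf_std = torch.zeros(4)

# lerp-based (as implemented): v <- beta*v + (1-beta)*g;  u = (1-beta)*g + beta*v
buf_lerp.lerp_(grad, 1 - beta)
update_lerp = grad.lerp(buf_lerp, beta)

# canonical Nesterov: v <- beta*v + g;  u = g + beta*v
buf_std.mul_(beta).add_(grad)
update_std = grad + beta * buf_std

print(update_lerp, update_std)  # same direction on the first step, but different scale and weighting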


256-407: AdaMuonOptimizer step: solid design with a small style nit

The overall step() logic is well-structured:

  • Clean separation of >=2D (AdaMuon) vs 1D (FP32 Adam) parameters.
  • Decoupled multiplicative weight decay only on >=2D params, as documented.
  • Per-parameter Adam state with bias correction via beta1_pow/beta2_pow.
  • Efficient bucketing by (rows, cols, device) plus batched Newton–Schulz on stacked sign matrices.
  • Per-element v_buffer EMA in FP32, followed by RMS-aligned scaling and shape-dependent adj_scale.

Two minor points:

  1. In the bucket loop, the update variable from the tuple is never read:

    for i, (p, update, orig_shape) in enumerate(bucket):
        state = self.state[p]
        orth_vec = orth_stacked[i].flatten().float()
        ...

    To satisfy linters and future readers, you can rename it to _update:

    for i, (p, _update, orig_shape) in enumerate(bucket):
        ...
  2. If you later want users to tune Newton–Schulz steps, ns_steps is already plumbed through the param group; exposing it via config (as you did with lr_adjust) would be straightforward.

Functionally this all looks correct and well tested given the dedicated test_adamuon.py suite.

source/tests/pt/test_adamuon.py (2)

9-14: Torch thread monkeypatch in tests: keep but simplify lambda signatures

The global monkeypatch of torch.set_num_interop_threads / torch.set_num_threads to no-ops is understandable to avoid thread reconfiguration warnings in the test environment.

To quiet linters and make the intent a bit clearer, you could drop the unused lambda parameters:

if torch_set_num_interop_threads is not None:
    torch.set_num_interop_threads = lambda *_, **__: None  # type: ignore[assignment]
if torch_set_num_threads is not None:
    torch.set_num_threads = lambda *_, **__: None  # type: ignore[assignment]

Functionally equivalent, just a small style tweak.


100-106: Use strict=True in zip in tests for extra safety

Both here:

for i, (p, init_p) in enumerate(zip(model.parameters(), initial_params)):
    ...

and later when zipping params1/params2 for state dict comparison, adding strict=True to zip would make mismatched lengths fail loudly instead of silently truncating:

for i, (p, init_p) in enumerate(zip(model.parameters(), initial_params, strict=True)):
    ...
for p1, p2 in zip(params1, params2, strict=True):
    ...

This is purely a robustness/readability improvement; the current tests are otherwise very solid and comprehensive.

Also applies to: 372-383

📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between fe1662d and 3046a13.

📒 Files selected for processing (5)
  • deepmd/pt/optimizer/__init__.py
  • deepmd/pt/optimizer/adamuon.py
  • deepmd/pt/train/training.py
  • deepmd/utils/argcheck.py
  • source/tests/pt/test_adamuon.py
🧰 Additional context used
🧬 Code graph analysis (2)
deepmd/pt/train/training.py (2)
deepmd/pt/optimizer/adamuon.py (2)
  • AdaMuonOptimizer (151-409)
  • step (234-409)
deepmd/pd/train/training.py (2)
  • step (734-978)
  • warm_up_linear (586-590)
deepmd/pt/optimizer/__init__.py (1)
deepmd/pt/optimizer/adamuon.py (1)
  • AdaMuonOptimizer (151-409)
🪛 Ruff (0.14.10)
source/tests/pt/test_adamuon.py

12-12: Unused lambda argument: args

(ARG005)


12-12: Unused lambda argument: kwargs

(ARG005)


14-14: Unused lambda argument: args

(ARG005)


14-14: Unused lambda argument: kwargs

(ARG005)


100-100: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)


372-372: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)

deepmd/pt/optimizer/adamuon.py

77-77: Avoid specifying long messages outside the exception class

(TRY003)


79-79: Avoid specifying long messages outside the exception class

(TRY003)


375-375: Loop control variable update not used within loop body

(B007)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (30)
  • GitHub Check: CodeQL analysis (python)
  • GitHub Check: Agent
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Test Python (2, 3.13)
  • GitHub Check: Test Python (6, 3.13)
  • GitHub Check: Test Python (5, 3.10)
  • GitHub Check: Test Python (3, 3.10)
  • GitHub Check: Test Python (5, 3.13)
  • GitHub Check: Test Python (6, 3.10)
  • GitHub Check: Test Python (4, 3.10)
  • GitHub Check: Test Python (4, 3.13)
  • GitHub Check: Test Python (1, 3.13)
  • GitHub Check: Test Python (3, 3.13)
  • GitHub Check: Test Python (1, 3.10)
  • GitHub Check: Test Python (2, 3.10)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Analyze (python)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Test C++ (true, true, true, false)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Test C++ (false, true, true, false)
  • GitHub Check: Test C++ (false, false, false, true)
  • GitHub Check: Test C++ (true, false, false, true)
🔇 Additional comments (1)
deepmd/pt/optimizer/__init__.py (1)

2-4: AdaMuon export wiring looks correct

Importing AdaMuonOptimizer here and adding it to __all__ cleanly exposes the new optimizer via deepmd.pt.optimizer. No issues from this change.

Also applies to: 12-12

@wanghan-iapcm wanghan-iapcm requested a review from iProzd January 8, 2026 05:13
@OutisLi OutisLi requested a review from wanghan-iapcm January 9, 2026 08:08

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In @deepmd/pt/optimizer/adamuon.py:
- Around line 40-109: The docstring for zeropower_via_newtonschulz5 should be
updated to explicitly state that the hard-coded coefficients (3.4445, -4.7750,
2.0315) are the "cursed quintic" from Muon/NanoGPT and that this variant
intentionally trades formal convergence guarantees for faster wall-clock
orthogonalization (they sum to ~0.701, so x=1 is not a fixed point and the
iteration can oscillate); revise the description to call this a
non-convergent/quasi-Newton-Schulz heuristic, warn users about possible
oscillation and lack of convergence guarantees, and mention that it operates in
bfloat16 and returns bfloat16 so numerical behavior may differ from a standard
convergent Newton-Schulz implementation (keep the rest of the function logic
unchanged).
🧹 Nitpick comments (3)
deepmd/pt/train/training.py (1)

707-725: Consider exposing lr_adjust and lr_adjust_coeff in get_opt_param.

The initialization correctly retrieves lr_adjust and lr_adjust_coeff from opt_param, but these parameters are not added to get_opt_param (lines 158-173). They will always use the defaults (10.0 and 0.2). If users should be able to configure these via the training config, consider adding them:

             # Muon/AdaMuon parameters
             "momentum": params.get("momentum", 0.95),
             "adam_beta1": params.get("adam_beta1", 0.9),
             "adam_beta2": params.get("adam_beta2", 0.95),
+            "lr_adjust": params.get("lr_adjust", 10.0),
+            "lr_adjust_coeff": params.get("lr_adjust_coeff", 0.2),

Otherwise, if this is intentional to keep the interface simpler, no change is needed.

deepmd/pt/optimizer/adamuon.py (2)

373-407: Unused loop variable update in tuple unpacking.

At line 375, the loop variable update is unpacked but never used within the loop body. This is because the update tensor was already processed when building the stacked tensor at line 370. Consider renaming to _ to indicate it's intentionally unused:

♻️ Suggested fix
-                for i, (p, update, orig_shape) in enumerate(bucket):
+                for i, (p, _update, orig_shape) in enumerate(bucket):

313-325: Add a comment documenting reliance on torch._foreach_* private APIs.

While torch._foreach_mul_, torch._foreach_lerp_, and torch._foreach_mul are private APIs (explicitly marked as subject to change without warning), PyTorch's official Adam optimizer uses them for performance-critical batch operations. Consider adding an inline comment explaining this deliberate choice, noting that these private APIs are used for optimization purposes and may require updates if PyTorch changes their API in future versions.
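For readers unfamiliar with these primitives, a minimal standalone illustration of the torch._foreach_* batching pattern follows (the exact calls in adamuon.py may differ):

import torch

params = [torch.randn(3, 3), torch.randn(5)]
grads = [torch.randn_like(p) for p in params]
exp_avgs = [torch.zeros_like(p) for p in params]
beta1 = 0.9

# One batched EMA update across all tensors in the list, without a Python loop:
torch._foreach_mul_(exp_avgs, beta1)                    # exp_avg *= beta1, per tensor
torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1)   # exp_avg += (1 - beta1) * grad
# torch._foreach_lerp_(exp_avgs, grads, 1 - beta1) is the fused equivalent of the two
# lines above; all of these are private APIs that may change between PyTorch releases.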

📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 3046a13 and 9a86678.

📒 Files selected for processing (3)
  • deepmd/pt/optimizer/adamuon.py
  • deepmd/pt/train/training.py
  • deepmd/utils/argcheck.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • deepmd/utils/argcheck.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2024-10-15T22:22:24.889Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4219
File: deepmd/utils/learning_rate.py:48-53
Timestamp: 2024-10-15T22:22:24.889Z
Learning: Methods in `deepmd/utils/learning_rate.py` that return NumPy scalar types should have return type annotations using the corresponding NumPy types, such as `np.float64`.

Applied to files:

  • deepmd/pt/optimizer/adamuon.py
🧬 Code graph analysis (2)
deepmd/pt/optimizer/adamuon.py (2)
deepmd/pt/train/training.py (1)
  • step (772-1144)
source/tests/pt/test_adamuon.py (1)
  • closure (328-334)
deepmd/pt/train/training.py (2)
deepmd/pt/optimizer/adamuon.py (2)
  • AdaMuonOptimizer (151-409)
  • step (234-409)
deepmd/pd/train/training.py (1)
  • step (734-978)
🪛 Ruff (0.14.10)
deepmd/pt/optimizer/adamuon.py

77-77: Avoid specifying long messages outside the exception class

(TRY003)


79-79: Avoid specifying long messages outside the exception class

(TRY003)


375-375: Loop control variable update not used within loop body

(B007)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (35)
  • GitHub Check: Test Python (9, 3.10)
  • GitHub Check: Test Python (8, 3.10)
  • GitHub Check: Test Python (9, 3.13)
  • GitHub Check: Test Python (11, 3.13)
  • GitHub Check: Test Python (10, 3.13)
  • GitHub Check: Test Python (2, 3.10)
  • GitHub Check: Test Python (6, 3.13)
  • GitHub Check: Test Python (12, 3.10)
  • GitHub Check: Test Python (8, 3.13)
  • GitHub Check: Test Python (11, 3.10)
  • GitHub Check: Test Python (5, 3.10)
  • GitHub Check: Test Python (4, 3.13)
  • GitHub Check: Test Python (4, 3.10)
  • GitHub Check: Test Python (6, 3.10)
  • GitHub Check: Test Python (2, 3.13)
  • GitHub Check: Test Python (3, 3.10)
  • GitHub Check: Test Python (1, 3.13)
  • GitHub Check: Test Python (3, 3.13)
  • GitHub Check: Test Python (1, 3.10)
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Test C++ (false, false, false, true)
  • GitHub Check: Test C++ (true, false, false, true)
  • GitHub Check: Test C++ (false, true, true, false)
  • GitHub Check: Test C++ (true, true, true, false)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Analyze (python)
🔇 Additional comments (8)
deepmd/pt/train/training.py (3)

44-48: LGTM!

The import of AdaMuonOptimizer is correctly placed with other optimizer imports and aligns with the module structure.


166-172: LGTM!

The new optimizer parameters are properly documented with comments and their default values align with the AdaMuonOptimizer constructor.


796-814: LGTM!

Adding "AdaMuon" to the optimizer type check correctly integrates it with the existing training loop, including warmup handling, gradient clipping, and scheduler stepping.

deepmd/pt/optimizer/adamuon.py (5)

1-38: LGTM!

Well-structured module with clear documentation explaining the optimizer's purpose and key improvements. The reference to the source repository is helpful for traceability.


112-148: LGTM!

The momentum update and Nesterov handling are correctly implemented. The reshape logic properly handles higher-dimensional tensors for matrix-based processing.


205-231: LGTM!

Clean constructor following the standard PyTorch Optimizer pattern with well-documented defaults.


391-399: Numerical stability handling looks solid.

The implementation properly guards against division by zero with eps additions at critical points (v_buffer sqrt at line 394, norm division at line 399). The use of FP32 for accumulator buffers (v_buffer, exp_avg, exp_avg_sq) helps maintain precision.


151-203: Well-documented class with comprehensive docstring.

The class docstring clearly explains the dual update strategy (AdaMuon for >=2D, Adam for 1D params), key features, and parameter semantics including the lr_adjust mode behavior. This aids usability significantly.

Comment on lines +40 to +109
def zeropower_via_newtonschulz5(
    G: torch.Tensor,
    steps: int = 5,
    eps: float = 1e-8,
) -> torch.Tensor:
    """
    Compute the zeroth power (orthogonalization) of a matrix via Newton-Schulz iteration.

    Uses quintic Newton-Schulz iteration to compute the orthogonal component of the
    input matrix. This is equivalent to computing U from the SVD decomposition G = USV^T.
    This implementation always performs Newton-Schulz in bfloat16 and returns a
    bfloat16 tensor.

    Parameters
    ----------
    G : torch.Tensor
        Input matrix to orthogonalize with shape (..., M, N).
    steps : int
        Number of Newton-Schulz iterations with default 5.
    eps : float
        Numerical stability epsilon for norm clamping with default 1e-8.

    Returns
    -------
    torch.Tensor
        Orthogonalized matrix in bfloat16 with same shape as input.

    Raises
    ------
    ValueError
        If G has fewer than 2 dimensions.
    ValueError
        If steps >= 100 (guard for efficiency).
    """
    # === Step 1. Validate ===
    if G.ndim < 2:
        raise ValueError("Input must have at least 2 dimensions (..., M, N).")
    if steps >= 100:
        raise ValueError("Number of steps must be less than 100 for efficiency.")

    a, b, c = (3.4445, -4.7750, 2.0315)

    # === Step 2. Cast to bf16 ===
    X = G.to(dtype=torch.bfloat16)

    # === Step 3. Transpose tall matrices ===
    if X.size(-2) > X.size(-1):
        X = X.mT

    # === Step 4. Normalize Frobenius norm to at most 1 ===
    X = X / X.norm(dim=(-2, -1), keepdim=True).clamp(min=eps)

    # === Step 5. Newton-Schulz iterations with fused GEMM ===
    for _ in range(steps):
        A = X @ X.mT
        # gram_update = b*A + c*(A@A) via addmm/baddbmm
        # X = a*X + gram_update@X via addmm/baddbmm
        if X.ndim == 2:
            gram_update = torch.addmm(A, A, A, beta=b, alpha=c)
            X = torch.addmm(X, gram_update, X, beta=a, alpha=1.0)
        else:
            gram_update = torch.baddbmm(A, A, A, beta=b, alpha=c)
            X = torch.baddbmm(X, gram_update, X, beta=a, alpha=1.0)

    # === Step 6. Transpose back if needed ===
    if G.size(-2) > G.size(-1):
        X = X.mT

    return X

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🌐 Web query:

Newton-Schulz quintic iteration coefficients 3.4445 -4.7750 2.0315

💡 Result:

Those numbers are the "cursed quintic" Newton–Schulz polynomial used in the Muon/NanoGPT speedrun:

f(x) = 3.4445 x − 4.7750 x^3 + 2.0315 x^5.

It was chosen to rapidly inflate small singular values (faster wall‑clock orthogonalization) but it is non‑convergent — its coefficients sum to 0.701 ≠ 1, so x=1 is not a fixed point and the iteration can oscillate. [1][2]

Sources:

  • Modula docs (discussion of the cursed quintic and coefficients). [1]
  • ArXiv paper on Chebyshev‑optimized Newton‑Schulz (context on higher‑degree Newton–Schulz polynomials). [2]

Update docstring to disclose non-convergent trade-off and clarify these are the "cursed quintic" coefficients.

The coefficients (3.4445, -4.7750, 2.0315) are the "cursed quintic" from Muon/NanoGPT, designed for speed. However, these do not form a standard convergent Newton-Schulz iteration—they sum to 0.701 ≠ 1, so x=1 is not a fixed point, and the iteration can oscillate.

The docstring should clarify that this trades convergence guarantees for faster wall-clock orthogonalization, not that it is a standard Newton-Schulz method. Users should know they are using a non-convergent variant with intentional trade-offs.

🧰 Tools
🪛 Ruff (0.14.10)

77-77: Avoid specifying long messages outside the exception class

(TRY003)


79-79: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In @deepmd/pt/optimizer/adamuon.py around lines 40 - 109, The docstring for
zeropower_via_newtonschulz5 should be updated to explicitly state that the
hard-coded coefficients (3.4445, -4.7750, 2.0315) are the "cursed quintic" from
Muon/NanoGPT and that this variant intentionally trades formal convergence
guarantees for faster wall-clock orthogonalization (they sum to ~0.701, so x=1
is not a fixed point and the iteration can oscillate); revise the description to
call this a non-convergent/quasi-Newton-Schulz heuristic, warn users about
possible oscillation and lack of convergence guarantees, and mention that it
operates in bfloat16 and returns bfloat16 so numerical behavior may differ from
a standard convergent Newton-Schulz implementation (keep the rest of the
function logic unchanged).


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In @deepmd/pt/optimizer/adamuon.py:
- Around line 44-113: The default eps=1e-8 in zeropower_via_newtonschulz5 is too
small for pure bfloat16 work—raise the default to at least 1e-3 (preferably
1e-1) or perform the normalization in float32 before casting to bfloat16:
compute norm = G.norm(dim=(-2,-1), keepdim=True).clamp(min=eps) in float32 and
divide G by that norm, then cast the result to torch.bfloat16 for the
Newton‑Schulz iterations; ensure the function signature default eps is updated
(or validated) to >=1e-3 if you choose the simpler default change.
🧹 Nitpick comments (1)
deepmd/pt/optimizer/adamuon.py (1)

379-379: Minor: Loop variable update is unused.

The loop unpacks (p, update, orig_shape) but only uses p and orig_shape. The update variable is unused because the orthogonalized sign matrix is retrieved via orth_stacked[i] instead.

While this is intentional (update was already used for sign computation in line 374), consider renaming to _ for clarity:

-for i, (p, update, orig_shape) in enumerate(bucket):
+for i, (p, _, orig_shape) in enumerate(bucket):
📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9a86678 and 700bcbc.

📒 Files selected for processing (1)
  • deepmd/pt/optimizer/adamuon.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2024-10-15T22:22:24.889Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4219
File: deepmd/utils/learning_rate.py:48-53
Timestamp: 2024-10-15T22:22:24.889Z
Learning: Methods in `deepmd/utils/learning_rate.py` that return NumPy scalar types should have return type annotations using the corresponding NumPy types, such as `np.float64`.

Applied to files:

  • deepmd/pt/optimizer/adamuon.py
🪛 Ruff (0.14.10)
deepmd/pt/optimizer/adamuon.py

81-81: Avoid specifying long messages outside the exception class

(TRY003)


83-83: Avoid specifying long messages outside the exception class

(TRY003)


379-379: Loop control variable update not used within loop body

(B007)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (34)
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Analyze (python)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Test Python (3, 3.13)
  • GitHub Check: Test Python (6, 3.10)
  • GitHub Check: Test Python (5, 3.13)
  • GitHub Check: Test Python (6, 3.13)
  • GitHub Check: Test Python (8, 3.10)
  • GitHub Check: Test Python (2, 3.10)
  • GitHub Check: Test Python (5, 3.10)
  • GitHub Check: Test Python (4, 3.13)
  • GitHub Check: Test Python (12, 3.13)
  • GitHub Check: Test Python (9, 3.10)
  • GitHub Check: Test Python (11, 3.10)
  • GitHub Check: Test Python (11, 3.13)
  • GitHub Check: Test Python (3, 3.10)
  • GitHub Check: Test Python (8, 3.13)
  • GitHub Check: Test Python (10, 3.13)
  • GitHub Check: Test Python (12, 3.10)
  • GitHub Check: Test Python (2, 3.13)
  • GitHub Check: Test Python (10, 3.10)
  • GitHub Check: Test Python (1, 3.13)
  • GitHub Check: Test Python (7, 3.10)
  • GitHub Check: Test Python (4, 3.10)
  • GitHub Check: Test Python (9, 3.13)
  • GitHub Check: Test Python (1, 3.10)
  • GitHub Check: Test Python (7, 3.13)
  • GitHub Check: Test C++ (false, true, true, false)
  • GitHub Check: Test C++ (true, false, false, true)
  • GitHub Check: Test C++ (true, true, true, false)
  • GitHub Check: Test C++ (false, false, false, true)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
🔇 Additional comments (8)
deepmd/pt/optimizer/adamuon.py (8)

116-152: LGTM!

The momentum preparation logic is clean and correct. The in-place momentum update is properly documented, and the Nesterov logic follows standard conventions.


209-235: LGTM!

The initialization follows PyTorch optimizer conventions correctly, with sensible default values that match the documentation.


255-316: LGTM!

The parameter collection and state initialization logic is well-structured. Key strengths:

  • Clean separation of AdaMuon (>=2D) and Adam (1D) paths.
  • Adam state tensors maintained in FP32 for numerical stability.
  • Proper bias correction setup with beta power tracking.

318-341: LGTM!

The weight decay and Adam update implementation is solid:

  • Decoupled weight decay (AdamW-style) is correctly applied before updates.
  • Efficient batched operations via _foreach primitives.
  • Full FP32 computation path for Adam with final dtype cast ensures numerical stability.

343-370: LGTM!

The bucketing strategy and scaling computation are well-designed:

  • Efficient batching by shape and device for Newton-Schulz.
  • Dual lr_adjust behavior provides flexibility for different scaling strategies.
  • Pre-computed bucket-level constants avoid redundant computation.

371-376: LGTM!

The sign-stabilized orthogonalization is correctly implemented:

  • contiguous() call ensures proper memory layout before batched operations.
  • Stacking sign matrices enables efficient batched Newton-Schulz processing.

409-411: LGTM!

The final parameter update correctly:

  • Reshapes the orthogonalized vector back to original shape.
  • Casts to parameter dtype.
  • Applies with negative learning rate scaling.

387-398: The review comment's analysis is incomplete and misses a crucial mitigating factor.

The v_buffer is initialized to zeros as stated, and the element-wise normalization calculation on the first step is mathematically correct: orth_vec / sqrt((1-momentum) * orth_vec²) ≈ sign(orth_vec) / sqrt(1-momentum). However, the review fails to account for the RMS-aligned global scaling applied immediately after (line 400-403), which significantly mitigates the first-step effect:

  1. After element-wise normalization (line 398): scaling factor ≈ 4.47x per-element
  2. Immediately after, global RMS-aligned scaling (line 403): orth_vec / norm(orth_vec) * 0.2 * sqrt(m*n)
  3. The global scaling re-normalizes the vector and applies a 0.2 coefficient, substantially dampening the first-step update

Additionally, zero initialization of the second-moment buffer is the standard practice in reference AdaMuon implementations and matches the paper's algorithm. This is intentional and not an instability concern—the RMS-aligned global scaling (0.2 * sqrt(m*n) / norm) is the paper's mechanism for controlling first-step behavior.

No warm-up schedule or non-zero initialization is needed.

Likely an incorrect or invalid review comment.
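A rough numeric illustration of the first-step interplay described above, assuming momentum = 0.95 and an RMS coefficient of 0.2 as quoted in this thread (not the actual optimizer code):

import torch

m, n = 64, 32
orth = torch.randn(m, n).sign()               # stand-in for an orthogonalized direction
momentum, eps, coeff = 0.95, 1e-8, 0.2

v = (1 - momentum) * orth * orth              # first EMA step from a zero v_buffer
per_elem = orth / (v.sqrt() + eps)            # ~ sign(orth) / sqrt(1 - momentum)
print(per_elem.abs().mean())                  # ~4.47x per-element inflation

scaled = per_elem / (per_elem.norm() + eps) * coeff * (m * n) ** 0.5
print(scaled.pow(2).mean().sqrt())            # RMS ~ 0.2 after the global scaling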

Comment on lines +44 to +113
def zeropower_via_newtonschulz5(
    G: torch.Tensor,
    steps: int = 5,
    eps: float = 1e-8,
) -> torch.Tensor:
    """
    Compute the zeroth power (orthogonalization) of a matrix via Newton-Schulz iteration.

    Uses quintic Newton-Schulz iteration to compute the orthogonal component of the
    input matrix. This is equivalent to computing U from the SVD decomposition G = USV^T.
    This implementation always performs Newton-Schulz in bfloat16 and returns a
    bfloat16 tensor.

    Parameters
    ----------
    G : torch.Tensor
        Input matrix to orthogonalize with shape (..., M, N).
    steps : int
        Number of Newton-Schulz iterations with default 5.
    eps : float
        Numerical stability epsilon for norm clamping with default 1e-8.

    Returns
    -------
    torch.Tensor
        Orthogonalized matrix in bfloat16 with same shape as input.

    Raises
    ------
    ValueError
        If G has fewer than 2 dimensions.
    ValueError
        If steps >= 100 (guard for efficiency).
    """
    # === Step 1. Validate ===
    if G.ndim < 2:
        raise ValueError("Input must have at least 2 dimensions (..., M, N).")
    if steps >= 100:
        raise ValueError("Number of steps must be less than 100 for efficiency.")

    a, b, c = (3.4445, -4.7750, 2.0315)

    # === Step 2. Cast to bf16 ===
    X = G.to(dtype=torch.bfloat16)

    # === Step 3. Transpose tall matrices ===
    if X.size(-2) > X.size(-1):
        X = X.mT

    # === Step 4. Normalize Frobenius norm to at most 1 ===
    X = X / X.norm(dim=(-2, -1), keepdim=True).clamp(min=eps)

    # === Step 5. Newton-Schulz iterations with fused GEMM ===
    for _ in range(steps):
        A = X @ X.mT
        # gram_update = b*A + c*(A@A) via addmm/baddbmm
        # X = a*X + gram_update@X via addmm/baddbmm
        if X.ndim == 2:
            gram_update = torch.addmm(A, A, A, beta=b, alpha=c)
            X = torch.addmm(X, gram_update, X, beta=a, alpha=1.0)
        else:
            gram_update = torch.baddbmm(A, A, A, beta=b, alpha=c)
            X = torch.baddbmm(X, gram_update, X, beta=a, alpha=1.0)

    # === Step 6. Transpose back if needed ===
    if G.size(-2) > G.size(-1):
        X = X.mT

    return X

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🌐 Web query:

bfloat16 numerical precision mantissa bits decimal digits representation

💡 Result:

  • Bit layout: 1 sign, 8 exponent, 7 fraction (mantissa) bits. [1][2]
  • Effective significand precision (normalized numbers) = 7 + 1 (implicit leading 1) = 8 bits. [1][2]
  • Decimal digits of precision ≈ log10(2^8) ≈ 2.41 → about 2–3 significant decimal digits. [2][3]

References
[1] oneAPI / oneDNN bfloat16 documentation.
[2] Wikipedia — Bfloat16 floating-point format.
[3] Explanation of bfloat16 bit layout and precision (bitstring / format docs).


🌐 Web query:

Newton-Schulz iteration epsilon eps parameter convergence reduced precision float16 bfloat16

💡 Result:

Short guidance and practical choices.

  • Machine-unit roundoff (u): float16 (binary16) has 10-bit fraction → u ≈ 2^-10 ≈ 9.77×10^-4. bfloat16 has 7-bit fraction → u ≈ 2^-7 = 7.8125×10^-3. Use these as the floor for any stopping tolerance. [1][2]

  • Stopping criterion (recommended): monitor the relative residual R_k = ||I - A X_k|| / ||A|| (or ||I - X_k A||). Stop when R_k ≤ C·u where C∈[1,100] depending on safety margin and whether you use FP32 accumulation. For pure float16/bfloat16 math use C≈10–100; for mixed-precision (FP16/BF16 compute, FP32 accumulators or casts) you can use C≈1–10. Typical practical tolerances:

    • float16: eps ≈ 1e-3 … 1e-1 (safer: 1e-2).
    • bfloat16: eps ≈ 8e-3 … 8e-2 (safer: 1e-1).
      These bounds follow from the unit roundoff values and observed stability of Newton–Schulz under reduced precision. [1][3][4]
  • Iteration count: Newton–Schulz converges quadratically after normalization; in practice 3–6 iterations (often 3–5) suffice when normalized and using quintic/cubic variants used in ML orthogonalization. Stop earlier if R_k ≤ eps to avoid chasing rounding noise. [3][4]

  • Practical stability tips

    • Always pre-normalize X0 (scale A so spectral norm ≤ 1) — this is essential for contractive Newton–Schulz behavior. Use mixed precision for the normalization constant if available. [3][5]
    • Prefer mixed precision: do mat-muls in FP16/BF16 but accumulate or compute residuals in FP32 when possible (reduces effective error and lets you set smaller eps). [2][6]
    • If only pure 16-bit is available, increase eps (looser tolerance) and/or keep a few critical buffers (scaling, residual) in FP32. [1][2][3]
    • If you need a guaranteed final FP32-level accuracy, perform a final correction or one Newton step in FP32 with the 16-bit result as initial guess. [4][5]

References

  • FP formats and bit-allocation (float16, bfloat16): Intel / ML systems summaries. [1][2]
  • Newton–Schulz convergence and low-precision usage (MDPI study, Muon/nanochat discussions on ML orthogonalization and practical iterations). [3][4]
  • Mixed-precision advice and TensorRT/BFloat16 practice notes. [5][6]

If you want, I can: (a) compute exact unit roundoffs and suggested numeric eps for your hardware, (b) suggest pseudocode for a robust stopping check using mixed precision, or (c) tune eps for a specific matrix condition number — tell me which.
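As a side note, a minimal version of the stopping/quality check mentioned in (b) might look like the sketch below; it is not part of the PR, and only the function name and module path are taken from the changed files.

import torch

from deepmd.pt.optimizer.adamuon import zeropower_via_newtonschulz5

def orthogonality_residual(X: torch.Tensor) -> float:
    """||I - X X^T||_F / ||I||_F in fp32, for a wide matrix X with roughly orthonormal rows."""
    Xf = X.float()
    eye = torch.eye(Xf.size(-2), dtype=torch.float32, device=Xf.device)
    return ((eye - Xf @ Xf.mT).norm() / eye.norm()).item()

G = torch.randn(16, 64)
print(orthogonality_residual(zeropower_via_newtonschulz5(G, steps=5)))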


🏁 Script executed:

cat -n deepmd/pt/optimizer/adamuon.py | sed -n '44,113p'

Repository: deepmodeling/deepmd-kit

Length of output: 2782


Epsilon value is inappropriate for bfloat16 computation—use at least 1e-3, preferably 1e-1.

The function performs all Newton-Schulz iterations in pure bfloat16 (line 88) but defaults to eps=1e-8 (line 47). However, bfloat16 has only ~7-8 bits of mantissa precision and a machine unit roundoff of approximately 7.8×10⁻³. The default eps=1e-8 is ~1000× smaller than this roundoff, making it ineffective for norm clamping at line 95.

For pure bfloat16 Newton-Schulz iterations, recommended epsilon values are 1e-3 to 1e-1 (with 1e-1 being safer). Either:

  1. Increase the default eps parameter to at least 1e-3, or
  2. Use mixed-precision normalization (perform the norm and clamping in float32 before casting to bfloat16).

This affects convergence stability and the effectiveness of the normalization guard against division by zero.
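A minimal sketch of option 2, assuming the same norm/clamp structure as the current function (not a drop-in patch):

import torch

def normalize_then_cast(G: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # Norm and clamp in fp32, so eps=1e-8 is compared against an fp32 value and
    # remains meaningful; only the normalized result is cast to bf16.
    norm = G.float().norm(dim=(-2, -1), keepdim=True).clamp(min=eps)
    return (G.float() / norm).to(torch.bfloat16)

X = normalize_then_cast(torch.randn(128, 64))
print(X.float().norm(dim=(-2, -1)))  # ~1.0, up to bf16 rounding of the result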

🧰 Tools
🪛 Ruff (0.14.10)

81-81: Avoid specifying long messages outside the exception class

(TRY003)


83-83: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In @deepmd/pt/optimizer/adamuon.py around lines 44 - 113, The default eps=1e-8
in zeropower_via_newtonschulz5 is too small for pure bfloat16 work—raise the
default to at least 1e-3 (preferably 1e-1) or perform the normalization in
float32 before casting to bfloat16: compute norm = G.norm(dim=(-2,-1),
keepdim=True).clamp(min=eps) in float32 and divide G by that norm, then cast the
result to torch.bfloat16 for the Newton‑Schulz iterations; ensure the function
signature default eps is updated (or validated) to >=1e-3 if you choose the
simpler default change.


codecov bot commented Jan 9, 2026

Codecov Report

❌ Patch coverage is 93.43066% with 9 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.96%. Comparing base (fe1662d) to head (700bcbc).
⚠️ Report is 6 commits behind head on master.

Files with missing lines (patch coverage / missing lines):
  • deepmd/pt/optimizer/adamuon.py: 95.31%, 6 lines missing ⚠️
  • deepmd/pt/train/training.py: 57.14%, 3 lines missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5130      +/-   ##
==========================================
- Coverage   82.15%   81.96%   -0.20%     
==========================================
  Files         709      713       +4     
  Lines       72468    73021     +553     
  Branches     3616     3616              
==========================================
+ Hits        59535    59849     +314     
- Misses      11769    12009     +240     
+ Partials     1164     1163       -1     

☔ View full report in Codecov by Sentry.
