feat(pt): add HybridMuonOptimizer #5149
base: master
Conversation
Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.
📝 Walkthrough
Adds a new HybridMuonOptimizer (with Newton–Schulz zeropower), exposes it in the public API, integrates it into training/config validation, and introduces unit tests for routing, updates, dtype handling, and state persistence.
Changes
Sequence Diagram(s)
```mermaid
sequenceDiagram
    participant Trainer
    participant HybridMuonOptimizer
    participant Router
    participant AdamPath
    participant MuonPath
    Trainer->>HybridMuonOptimizer: step(grads)
    HybridMuonOptimizer->>Router: classify params (first step)
    Router-->>HybridMuonOptimizer: partitions (adam_1d, adam_matrix, muon_params)
    HybridMuonOptimizer->>AdamPath: update 1D params (exp_avg, exp_avg_sq)
    HybridMuonOptimizer->>AdamPath: update small 2D matrices (Adam fallback)
    HybridMuonOptimizer->>MuonPath: update large 2D matrices (momentum, NS orthogonalize)
    MuonPath->>MuonPath: apply lr_adjust, weight decay, per-bucket scaling
    HybridMuonOptimizer-->>Trainer: updated params / state
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Possibly related PRs
Suggested reviewers
🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings.
✨ Finishing touches
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In @deepmd/pt/optimizer/muon.py:
- Around line 466-544: raw_deltas are float32 but adam_matrix_params may be
bf16/fp16 causing dtype-mismatch on torch._foreach_add_; before calling
torch._foreach_add_ convert each delta to its corresponding param's dtype/device
(e.g. delta = delta.to(param.dtype).to(param.device) or delta =
delta.type_as(param))—do this after clipping/scaling and then call
torch._foreach_add_(adam_matrix_params, casted_deltas) so updates match
parameter dtypes.
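A minimal sketch of that cast, assuming the variable names used in the comment (`raw_deltas`, `adam_matrix_params`) and that clipping/scaling has already happened:

```python
import torch


def apply_adam_matrix_deltas(adam_matrix_params, raw_deltas):
    # Cast each float32 delta to its parameter's device/dtype, then apply
    # the batched in-place add so updates match parameter dtypes.
    casted_deltas = [
        d.to(device=p.device, dtype=p.dtype)
        for p, d in zip(adam_matrix_params, raw_deltas, strict=True)
    ]
    torch._foreach_add_(adam_matrix_params, casted_deltas)
```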
In @source/tests/pt/test_muon.py:
- Around line 21-56: The test should skip on devices that don't support bfloat16
matmul: in test_orthogonalization (and optionally test_shape_and_dtype) probe
bf16 matmul support by attempting a small BF16 matmul on self.device inside a
try/except (e.g., create two tiny tensors with dtype=torch.bfloat16 and call
.matmul or torch.matmul) and call self.skipTest with an explanatory message if
it raises or is unsupported; use self.skipTest rather than asserting so CI
quietly skips environments where zeropower_via_newtonschulz5's BF16 path cannot
run reliably.
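A hedged sketch of that probe (the class and device attribute are illustrative, not the actual test file):

```python
import unittest

import torch


class Bf16ProbeExample(unittest.TestCase):
    def setUp(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def test_orthogonalization(self):
        # Probe bf16 matmul support and skip quietly where it is unsupported.
        try:
            a = torch.randn(2, 2, dtype=torch.bfloat16, device=self.device)
            torch.matmul(a, a)
        except (RuntimeError, TypeError):
            self.skipTest("bfloat16 matmul not supported on this device")
        # ... the actual zeropower_via_newtonschulz5 assertions would follow ...
```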
🧹 Nitpick comments (3)
deepmd/utils/argcheck.py (1)
3399-3526: Config surface matches the runtime wiring; minor doc/UX nits.
- Good: Muon exposes `min_2d_dim`, `lr_adjust`, `lr_adjust_coeff`, and Adam betas consistent with `deepmd/pt/train/training.py`.
- Consider clarifying in docs that `muon_momentum` is an alias for both AdaMuon/Muon within their respective opt_type blocks (to reduce confusion).
source/tests/pt/test_muon.py (1)
80-82: Prefer `zip(..., strict=True)` in tests. Avoids silently ignoring length mismatches.
Proposed tweak
```diff
- for i, (p, init_p) in enumerate(zip(model.parameters(), initial_params)):
+ for i, (p, init_p) in enumerate(zip(model.parameters(), initial_params, strict=True)):
      self.assertFalse(torch.allclose(p, init_p), f"Parameter {i} did not change")
```
deepmd/pt/optimizer/muon.py (1)
77-89: Consider wrapping `torch.compile()` in try/except for robustness on unsupported devices or graph patterns. While the Newton-Schulz functions use standard operations that should compile reliably, adding a fallback allows graceful degradation if compilation fails on certain hardware configurations or edge cases:
```python
try:
    return torch.compile(fn, fullgraph=True, dynamic=True)
except Exception:
    return fn
```
This is particularly useful since `fullgraph=True` can fail hard rather than gracefully degrade. Given PyTorch 2.7+ is required (where `torch.compile` is stable), the redundant `hasattr(torch, "compile")` check can be simplified or removed.
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- deepmd/pt/optimizer/__init__.py
- deepmd/pt/optimizer/muon.py
- deepmd/pt/train/training.py
- deepmd/utils/argcheck.py
- source/tests/pt/test_muon.py
🧰 Additional context used
🧠 Learnings (3)
📓 Common learnings
Learnt from: OutisLi
Repo: deepmodeling/deepmd-kit PR: 5130
File: deepmd/pt/optimizer/adamuon.py:40-109
Timestamp: 2026-01-10T04:28:11.377Z
Learning: The coefficients (3.4445, -4.7750, 2.0315) in the `zeropower_via_newtonschulz5` function in `deepmd/pt/optimizer/adamuon.py` are standard across AdaMuon implementations and should not be modified for consistency with the reference implementations.
📚 Learning: 2026-01-10T04:28:11.377Z
Learnt from: OutisLi
Repo: deepmodeling/deepmd-kit PR: 5130
File: deepmd/pt/optimizer/adamuon.py:40-109
Timestamp: 2026-01-10T04:28:11.377Z
Learning: The coefficients (3.4445, -4.7750, 2.0315) in the `zeropower_via_newtonschulz5` function in `deepmd/pt/optimizer/adamuon.py` are standard across AdaMuon implementations and should not be modified for consistency with the reference implementations.
Applied to files:
- deepmd/pt/optimizer/muon.py
- deepmd/pt/train/training.py
- source/tests/pt/test_muon.py
📚 Learning: 2024-11-25T07:42:55.735Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4302
File: deepmd/pd/infer/inference.py:35-38
Timestamp: 2024-11-25T07:42:55.735Z
Learning: In the file `deepmd/pd/infer/inference.py`, when loading the model checkpoint in the `Tester` class, it's acceptable to not include additional error handling for loading the model state dictionary.
Applied to files:
deepmd/pt/train/training.py
🧬 Code graph analysis (3)
deepmd/pt/optimizer/muon.py (2)
deepmd/pt/train/training.py (1)
step (799-1171)
source/tests/pt/test_adamuon.py (1)
closure(328-334)
deepmd/pt/train/training.py (1)
deepmd/pt/optimizer/muon.py (1)
MuonOptimizer(226-646)
source/tests/pt/test_muon.py (1)
deepmd/pt/optimizer/muon.py (3)
MuonOptimizer (226-646), zeropower_via_newtonschulz5 (159-186), step (379-646)
🪛 Ruff (0.14.10)
deepmd/pt/optimizer/muon.py
186-186: Avoid specifying long messages outside the exception class
(TRY003)
216-216: Avoid specifying long messages outside the exception class
(TRY003)
308-308: Avoid specifying long messages outside the exception class
(TRY003)
source/tests/pt/test_muon.py
80-80: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
🔇 Additional comments (6)
deepmd/pt/optimizer/__init__.py (1)
11-15: Public export wiring looks correct. `MuonOptimizer` is imported and added to `__all__`, enabling `from deepmd.pt.optimizer import MuonOptimizer`.
deepmd/pt/train/training.py (4)
44-49: Import + optimizer exposure is consistent with new public API.
159-177: Nice hardening: provide LKF defaults + Muon knobs in `get_opt_param`. Prevents missing-key crashes when LKF-specific keys aren’t present.
720-754: Muon optimizer integration looks correct; ensure state dict restore is covered. `min_2d_dim` is passed only for Muon (as expected), and scheduler + resume path are consistent with Adam/AdaMuon branches.
823-829: LR display/scheduler path correctly includes "Muon".deepmd/pt/optimizer/muon.py (1)
65-75: Keep NS coefficients as-is (they match standard references). `NS_COEFF_A/B/C` match the canonical (3.4445, -4.7750, 2.0315). Based on learnings.
Pull request overview
This pull request adds a new Muon optimizer to the DeePMD-kit PyTorch backend. Muon accumulates gradients into a momentum buffer and then applies Newton-Schulz orthogonalization to the momentum-updated gradients, producing orthogonalized updates for weight matrices. The optimizer uses different update strategies based on parameter dimensionality: Muon with Newton-Schulz for >=2D parameters, and Adam for 1D parameters (biases, norms).
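For orientation, a rough sketch of that per-matrix update; this is illustrative only (the real logic lives in `MuonOptimizer.step()`, a single-tensor call signature for the Newton-Schulz helper is assumed, and `scale` stands in for whatever the configured `lr_adjust` mode computes):

```python
from deepmd.pt.optimizer.muon import zeropower_via_newtonschulz5


def muon_matrix_update_sketch(param, grad, buf, lr, momentum=0.95, scale=1.0):
    buf.lerp_(grad, 1.0 - momentum)                            # momentum EMA of the gradient
    update = grad.lerp(buf, momentum)                          # Nesterov-style lookahead
    update = zeropower_via_newtonschulz5(update.bfloat16())    # orthogonalize the update
    param.add_(update.to(param.dtype), alpha=-lr * scale)      # apply the scaled update
```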
Changes:
- Implemented MuonOptimizer with Newton-Schulz orthogonalization algorithm
- Added configuration support for Muon optimizer parameters in argcheck
- Integrated Muon optimizer into the training pipeline with scheduler support
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.
Show a summary per file
| File | Description |
|---|---|
| deepmd/pt/optimizer/muon.py | New file implementing MuonOptimizer with Newton-Schulz orthogonalization for >=2D params and Adam for 1D params |
| deepmd/pt/optimizer/__init__.py | Added MuonOptimizer to module exports |
| deepmd/utils/argcheck.py | Added Muon optimizer configuration with parameters for momentum, Adam betas, weight decay, lr_adjust, and min_2d_dim |
| deepmd/pt/train/training.py | Integrated Muon optimizer initialization, parameter extraction, and scheduler setup; minor string formatting improvements |
| source/tests/pt/test_muon.py | Comprehensive test suite covering Newton-Schulz orthogonalization, optimizer step behavior, parameter routing, weight decay, lr_adjust modes, and state dict handling |
Codecov Report
❌ Patch coverage is
Additional details and impacted files
```
@@            Coverage Diff             @@
##           master    #5149      +/-   ##
==========================================
- Coverage   81.95%   81.93%   -0.02%
==========================================
  Files         713      714       +1
  Lines       72985    73277     +292
  Branches     3617     3617
==========================================
+ Hits        59812    60043     +231
- Misses      12010    12072      +62
+ Partials     1163     1162       -1
```
☔ View full report in Codecov by Sentry.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
deepmd/pt/train/training.py (2)
720-753: AdaMuon does not initialize `self.scheduler` but step() requires it.
The `AdaMuon` branch (lines 720-732) creates the optimizer without initializing `self.scheduler`, but the `step()` method at lines 824 and 841 assumes `self.scheduler` exists for both `AdaMuon` and `HybridMuon`. This will cause an `AttributeError` at runtime when using AdaMuon. Add scheduler initialization to the `AdaMuon` branch:
```python
if optimizer_state_dict is not None and self.restart_training:
    self.optimizer.load_state_dict(optimizer_state_dict)
self.scheduler = torch.optim.lr_scheduler.LambdaLR(
    self.optimizer,
    lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
)
```
159-177: Config-schema mismatch: LKF preference parameters are code-accessible but not schema-exposed.
`get_opt_param()` reads `kf_start_pref_e`, `kf_limit_pref_e`, `kf_start_pref_f`, and `kf_limit_pref_f` (lines 162–165), and these are actively used to calculate preference weights during training (in the step method for EnergyStdLoss). However, the LKF schema in `deepmd/utils/argcheck.py` only defines `kf_blocksize`. Users cannot configure these preference parameters via the config file because schema validation will not recognize them.
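If the schema were to expose them, a hedged sketch in the dargs `Argument` style that `argcheck.py` already uses; the defaults and docstrings here are placeholders, not values taken from the code:

```python
from dargs import Argument

lkf_pref_args = [
    Argument("kf_start_pref_e", [float, int], optional=True, default=1.0,
             doc="Hypothetical: starting energy prefactor for LKF."),
    Argument("kf_limit_pref_e", [float, int], optional=True, default=1.0,
             doc="Hypothetical: limit energy prefactor for LKF."),
    Argument("kf_start_pref_f", [float, int], optional=True, default=1.0,
             doc="Hypothetical: starting force prefactor for LKF."),
    Argument("kf_limit_pref_f", [float, int], optional=True, default=1.0,
             doc="Hypothetical: limit force prefactor for LKF."),
]
```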
🤖 Fix all issues with AI agents
In @source/tests/pt/test_hybrid_muon.py:
- Around line 107-109: The test loop uses zip(model.parameters(),
initial_params) which can silently truncate if lengths differ; update the loop
in test_hybrid_muon.py to use zip(model.parameters(), initial_params,
strict=True) so mismatched lengths raise an error (requires Python 3.10+),
keeping the same enumerate and assertion (i, p, init_p identifiers unchanged).
🧹 Nitpick comments (5)
deepmd/pt/optimizer/hybrid_muon.py (4)
88-100: Fix type annotations: `callable` is not a type.
Using the builtin `callable` in annotations is non-idiomatic and breaks type checking; use `collections.abc.Callable` (or `typing.Callable`) instead.
Proposed diff
```diff
-from typing import (
+from typing import (
     TYPE_CHECKING,
     Any,
 )
@@
-if TYPE_CHECKING:
+if TYPE_CHECKING:
     from collections.abc import (
+        Callable,
         Iterable,
     )
@@
 def _maybe_compile(
-    fn: callable,
-) -> callable:
+    fn: "Callable[..., Any]",
+) -> "Callable[..., Any]":
@@
 def step(
     self,
-    closure: callable | None = None,
+    closure: "Callable[[], torch.Tensor]" | None = None,
 ) -> torch.Tensor | None:
```
Also applies to: 389-393
76-86: Consider aligning `NS_EPS` with AdaMuon’s Newton–Schulz epsilon for consistency.
Coefficients match the standard (good), but `NS_EPS=1e-7` differs from the AdaMuon implementation used elsewhere in this repo (and the learned “don’t change eps” guidance). If the change is intentional, a short comment explaining why HybridMuon needs a different epsilon would help. Based on learnings, keep constants consistent unless there’s a measured reason.
Also applies to: 120-128, 154-162
335-388: Static routing: please document that routing won’t change after the first `step()`.
This is probably fine, but it’s worth calling out explicitly because parameter freezing/unfreezing or adding param groups mid-training won’t be reflected after `_routing_built=True`.
477-555: Small-2D Adam fallback: avoid hard-coded magic caps (or expose them).
`max_rel_change=0.05`, `abs_floor=1e-3*sqrt(numel)`, and `min(lr_adjust_coeff, 0.1)` are important stability knobs but currently “hidden”. Consider making them constants (at least) or args so behavior is tunable and testable.
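A minimal sketch of the "make them constants" option, with illustrative names around the values quoted above:

```python
# Stability knobs for the small-2D Adam fallback (names are illustrative).
MATRIX_FALLBACK_MAX_REL_CHANGE = 0.05   # cap on the relative per-step parameter change
MATRIX_FALLBACK_ABS_FLOOR_COEFF = 1e-3  # multiplied by sqrt(numel) for the absolute floor
MATRIX_FALLBACK_LR_CAP = 0.1            # used as min(lr_adjust_coeff, MATRIX_FALLBACK_LR_CAP)
```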
15-38: BF16 gating is reasonable, but import-time probing can be a little heavy.
Not a blocker, but if this ever becomes flaky on CI, consider moving the probe into `setUpClass` to avoid side effects on import.
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- deepmd/pt/optimizer/__init__.py
- deepmd/pt/optimizer/hybrid_muon.py
- deepmd/pt/train/training.py
- deepmd/utils/argcheck.py
- source/tests/pt/test_hybrid_muon.py
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2026-01-10T04:28:18.703Z
Learnt from: OutisLi
Repo: deepmodeling/deepmd-kit PR: 5130
File: deepmd/pt/optimizer/adamuon.py:40-109
Timestamp: 2026-01-10T04:28:18.703Z
Learning: The coefficients (3.4445, -4.7750, 2.0315) in the `zeropower_via_newtonschulz5` function in `deepmd/pt/optimizer/adamuon.py` are standard across AdaMuon implementations and should not be modified for consistency with the reference implementations.
Applied to files:
- deepmd/pt/optimizer/hybrid_muon.py
- source/tests/pt/test_hybrid_muon.py
- deepmd/pt/train/training.py
📚 Learning: 2026-01-10T04:29:25.299Z
Learnt from: OutisLi
Repo: deepmodeling/deepmd-kit PR: 5130
File: deepmd/pt/optimizer/adamuon.py:44-113
Timestamp: 2026-01-10T04:29:25.299Z
Learning: The epsilon parameter (eps=1e-8) in the `zeropower_via_newtonschulz5` function in `deepmd/pt/optimizer/adamuon.py` follows the standard PyTorch AdaMuon reference implementations (which typically use eps=1e-7) and should not be modified for consistency with reference implementations, despite being smaller than bfloat16's machine epsilon.
Applied to files:
deepmd/pt/optimizer/hybrid_muon.py
📚 Learning: 2024-11-25T07:42:55.735Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4302
File: deepmd/pd/infer/inference.py:35-38
Timestamp: 2024-11-25T07:42:55.735Z
Learning: In the file `deepmd/pd/infer/inference.py`, when loading the model checkpoint in the `Tester` class, it's acceptable to not include additional error handling for loading the model state dictionary.
Applied to files:
deepmd/pt/train/training.py
🧬 Code graph analysis (4)
deepmd/pt/optimizer/hybrid_muon.py (1)
source/tests/pt/test_adamuon.py (1)
closure(328-334)
source/tests/pt/test_hybrid_muon.py (2)
deepmd/pt/optimizer/hybrid_muon.py (1)
zeropower_via_newtonschulz5 (170-197)
source/install/build_tf.py (1)
env(829-839)
deepmd/pt/optimizer/__init__.py (1)
deepmd/pt/optimizer/hybrid_muon.py (1)
HybridMuonOptimizer(237-657)
deepmd/pt/train/training.py (1)
deepmd/pt/optimizer/hybrid_muon.py (1)
HybridMuonOptimizer(237-657)
🪛 Ruff (0.14.10)
deepmd/pt/optimizer/hybrid_muon.py
197-197: Avoid specifying long messages outside the exception class
(TRY003)
227-227: Avoid specifying long messages outside the exception class
(TRY003)
319-319: Avoid specifying long messages outside the exception class
(TRY003)
source/tests/pt/test_hybrid_muon.py
32-32: Consider moving this statement to an else block
(TRY300)
107-107: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (40)
- GitHub Check: Test Python (10, 3.10)
- GitHub Check: Test Python (6, 3.13)
- GitHub Check: Test Python (12, 3.13)
- GitHub Check: Test Python (5, 3.13)
- GitHub Check: Test Python (12, 3.10)
- GitHub Check: Test Python (11, 3.13)
- GitHub Check: Test Python (9, 3.13)
- GitHub Check: Test Python (4, 3.10)
- GitHub Check: Test Python (8, 3.13)
- GitHub Check: Test Python (8, 3.10)
- GitHub Check: Test Python (11, 3.10)
- GitHub Check: Test Python (4, 3.13)
- GitHub Check: Test Python (9, 3.10)
- GitHub Check: Test Python (7, 3.13)
- GitHub Check: Test Python (10, 3.13)
- GitHub Check: Test Python (6, 3.10)
- GitHub Check: Test Python (5, 3.10)
- GitHub Check: Test Python (2, 3.13)
- GitHub Check: Test Python (7, 3.10)
- GitHub Check: Test Python (1, 3.13)
- GitHub Check: Test Python (3, 3.10)
- GitHub Check: Test Python (3, 3.13)
- GitHub Check: Test Python (2, 3.10)
- GitHub Check: Test Python (1, 3.10)
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Test C++ (false, false, false, true)
- GitHub Check: Test C++ (true, false, false, true)
- GitHub Check: Test C++ (true, true, true, false)
- GitHub Check: Test C++ (false, true, true, false)
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Analyze (python)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
🔇 Additional comments (6)
deepmd/utils/argcheck.py (2)
3399-3407: AdaMuon momentum alias looks good for backward compatibility.
Adding `alias=["muon_momentum"]` should help config migrations without breaking existing `momentum` usage.
3454-3532: Documentation claim about `torch.optim.Muon` is accurate, no changes needed. PyTorch 2.7+ (the minimum supported version for GPU, and 2.8+ for CPU per pyproject.toml) includes `torch.optim.Muon`, which is indeed designed to optimize only 2D parameters via Newton-Schulz orthogonalization. The doc statement correctly distinguishes DeePMD's HybridMuon (supporting both 1D and 2D parameters) from PyTorch's Muon (2D only), and is not misleading.
deepmd/pt/optimizer/__init__.py (1)
1-20: Public export wiring looks correct.
Import + `__all__` update is consistent with training.py usage.
733-746: HybridMuon integration looks consistent with the new optimizer interface.
Passing `lr_adjust`/`lr_adjust_coeff`/`min_2d_dim` matches `HybridMuonOptimizer.__init__`, and state_dict loading mirrors Adam/AdaMuon flows.
639-641: Logging format changes look fine.
The single-line warning/info logs are clearer and keep log parsing simple.
Also applies to: 1564-1566
source/tests/pt/test_hybrid_muon.py (1)
40-228: Test coverage looks solid for routing + state_dict round-trip.
The separation checks (`momentum_buffer` vs `exp_avg`/`exp_avg_sq`) and `min_2d_dim` fallback test should catch most integration regressions.
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In @deepmd/pt/optimizer/hybrid_muon.py:
- Around line 684-724: The muon update computes an `orth` delta (via
_zeropower_via_newtonschulz5_2d and _zeropower_via_newtonschulz5_3d) in bf16 but
applies it in-place to params of potentially different dtypes causing runtime
errors; cast the delta(s) to the target parameter dtype before calling in-place
add_. Specifically, in the single-matrix branch cast `orth`/`delta` to
`entry["param"].dtype` before `entry["param"].add_`, and in the batched branch
cast each `orth[i]` (or the slice used) to `params[i].dtype` (or call
.to(params[i].dtype)) before `params[i].add_`.
- Around line 77-99: The _maybe_compile function currently calls
torch.compile(fn, fullgraph=True, dynamic=True) at import time and can raise
exceptions that break imports; wrap the torch.compile call in a try/except that
catches Exception (or RuntimeError) and returns the original fn on any
compilation failure, while preserving the existing default_device check and
behavior; ensure the fallback logs or silently ignores the compile error and
returns fn so that _maybe_compile and functions wrapped by it (e.g., any callers
of _maybe_compile) remain usable when torch.compile is unavailable or fails.
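A minimal sketch of that fallback (`_maybe_compile` is the function named above; the `default_device` check it mentions is omitted here, and the logging is optional):

```python
import logging

import torch

log = logging.getLogger(__name__)


def _maybe_compile(fn):
    # Try to compile for speed, but never let a compilation failure
    # (e.g. fullgraph=True on an unsupported backend) propagate to the caller.
    if not hasattr(torch, "compile"):
        return fn
    try:
        return torch.compile(fn, fullgraph=True, dynamic=True)
    except Exception:
        log.debug("torch.compile failed; falling back to eager execution", exc_info=True)
        return fn
```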
🧹 Nitpick comments (2)
deepmd/utils/argcheck.py (1)
3454-3542: HybridMuon arg schema matches the intended routing knobs; one doc nit: “Nesterov” formula wording is slightly off.
You describe `m_t = beta*m_{t-1} + (1-beta)*g_t` (momentum EMA) and later apply a Nesterov-style lookahead; consider rewording to avoid implying the EMA itself is “Nesterov”.
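One hedged way to make the distinction explicit, using the doc's own symbols: the EMA is

$$m_t = \beta\, m_{t-1} + (1-\beta)\, g_t,$$

and only the subsequent lookahead that feeds orthogonalization,

$$\tilde{g}_t = \beta\, m_t + (1-\beta)\, g_t,$$

is the Nesterov-style part (the exact lookahead form is an assumption about the implementation).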
deepmd/pt/train/training.py (1)
159-178: `get_opt_param()` now always includes LKF + Muon knobs; please sanity-check you don’t silently accept misspelled keys.
Not a blocker, but when configs evolve, it’s easy to carry dead/typoed fields unnoticed.
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
deepmd/pt/optimizer/hybrid_muon.pydeepmd/pt/train/training.pydeepmd/utils/argcheck.py
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2026-01-10T04:28:18.703Z
Learnt from: OutisLi
Repo: deepmodeling/deepmd-kit PR: 5130
File: deepmd/pt/optimizer/adamuon.py:40-109
Timestamp: 2026-01-10T04:28:18.703Z
Learning: The coefficients (3.4445, -4.7750, 2.0315) in the `zeropower_via_newtonschulz5` function in `deepmd/pt/optimizer/adamuon.py` are standard across AdaMuon implementations and should not be modified for consistency with the reference implementations.
Applied to files:
deepmd/pt/train/training.pydeepmd/pt/optimizer/hybrid_muon.py
📚 Learning: 2024-11-25T07:42:55.735Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4302
File: deepmd/pd/infer/inference.py:35-38
Timestamp: 2024-11-25T07:42:55.735Z
Learning: In the file `deepmd/pd/infer/inference.py`, when loading the model checkpoint in the `Tester` class, it's acceptable to not include additional error handling for loading the model state dictionary.
Applied to files:
deepmd/pt/train/training.py
📚 Learning: 2026-01-10T04:29:25.299Z
Learnt from: OutisLi
Repo: deepmodeling/deepmd-kit PR: 5130
File: deepmd/pt/optimizer/adamuon.py:44-113
Timestamp: 2026-01-10T04:29:25.299Z
Learning: The epsilon parameter (eps=1e-8) in the `zeropower_via_newtonschulz5` function in `deepmd/pt/optimizer/adamuon.py` follows the standard PyTorch AdaMuon reference implementations (which typically use eps=1e-7) and should not be modified for consistency with reference implementations, despite being smaller than bfloat16's machine epsilon.
Applied to files:
deepmd/pt/optimizer/hybrid_muon.py
🧬 Code graph analysis (2)
deepmd/pt/train/training.py (1)
deepmd/pt/optimizer/hybrid_muon.py (1)
HybridMuonOptimizer(236-725)
deepmd/pt/optimizer/hybrid_muon.py (2)
deepmd/pt/train/training.py (1)
step (801-1173)
source/tests/pt/test_adamuon.py (1)
closure(328-334)
🪛 Ruff (0.14.10)
deepmd/pt/optimizer/hybrid_muon.py
196-196: Avoid specifying long messages outside the exception class
(TRY003)
226-226: Avoid specifying long messages outside the exception class
(TRY003)
324-324: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (40)
- GitHub Check: Test Python (5, 3.10)
- GitHub Check: Test Python (11, 3.10)
- GitHub Check: Test Python (7, 3.13)
- GitHub Check: Test Python (10, 3.13)
- GitHub Check: Test Python (11, 3.13)
- GitHub Check: Test Python (4, 3.10)
- GitHub Check: Test Python (10, 3.10)
- GitHub Check: Test Python (12, 3.10)
- GitHub Check: Test Python (8, 3.13)
- GitHub Check: Test Python (9, 3.13)
- GitHub Check: Test Python (9, 3.10)
- GitHub Check: Test Python (4, 3.13)
- GitHub Check: Test Python (6, 3.10)
- GitHub Check: Test Python (12, 3.13)
- GitHub Check: Test Python (3, 3.13)
- GitHub Check: Test Python (8, 3.10)
- GitHub Check: Test Python (1, 3.13)
- GitHub Check: Test Python (5, 3.13)
- GitHub Check: Test Python (7, 3.10)
- GitHub Check: Test Python (1, 3.10)
- GitHub Check: Test Python (6, 3.13)
- GitHub Check: Test Python (3, 3.10)
- GitHub Check: Test Python (2, 3.13)
- GitHub Check: Test Python (2, 3.10)
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Test C++ (false, false, false, true)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Analyze (python)
- GitHub Check: Test C++ (false, true, true, false)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Test C++ (true, false, false, true)
- GitHub Check: Test C++ (true, true, true, false)
🔇 Additional comments (7)
deepmd/utils/argcheck.py (1)
3396-3450: AdaMuon: `muon_momentum` alias looks good; please ensure it’s documented consistently in user docs/examples.
This improves backward/forward compatibility for config naming (esp. when users think in “Muon momentum” rather than generic “momentum”).
deepmd/pt/optimizer/hybrid_muon.py (3)
101-197: Newton–Schulz kernels: coeffs look consistent with reference implementations (good).
One thing to verify: using a single `EPS = 1e-7` for both NS norm clamp and Adam epsilon may be intentional, but it’s a behavior change vs AdaMuon-style eps choices. Based on learnings, keep coeffs unchanged; please double-check eps expectations.
341-407: Static routing is sensible, but it assumes parameter ranks/shapes stay stable across training.
That’s probably fine for this codebase; just be aware it will misroute if someone swaps modules/params mid-training (rare).
443-627: The implementation looks sound. The bias-correction formula matches PyTorch's Adam (state["beta1_pow"] and state["beta2_pow"] are cumulative products β^t, giving bias_corr = 1 − β^t), and there's existing test coverage for state_dict roundtrip in `TestHybridMuonOptimizerStateDict.test_state_dict_save_load`. The use of `torch._foreach_*` is acceptable for the supported torch version range (≥2.7), and this pattern is established elsewhere in the codebase (e.g., adamuon.py). No action needed.
deepmd/pt/train/training.py (3)
44-49: Training import + optimizer registry wiring for `HybridMuonOptimizer` looks consistent.
825-844: Step-loop inclusion (`HybridMuon` alongside `Adam`/`AdaMuon`) is correct for scheduler LR tracking.
734-755: HybridMuon init matches argcheck + optimizer signature; `muon_2d_only` and `min_2d_dim` are documented in class docstring and argcheck, with test examples provided. Remove the unrelated bf16 dtype query from this review.
fix(pt): Muon bug fix
feat&fix(pt): Muon add bf16 support
feat(pt): use tf32 for Muon
fix(pt): Use 1e-8 for Muon
feat(pt): Update Muon
fix(pt): use the same lr for adam inside Muon
feat(pt): add match_rms for Muon
feat(pt): adjust Muon
feat(pt): Update Muon
(cherry picked from commit 9b4e63d)
(cherry picked from commit 46fcb7d)
(cherry picked from commit 1dd737f)
(cherry picked from commit c6f7e9f)
Changes:
1. Remove dtype conversion: NS output (bfloat16) now directly applied to parameters, matching torch.optim.Muon behavior where PyTorch handles mixed precision automatically.
2. Add muon_2d_only parameter (default True): When True, only 2D parameters use Muon; >2D parameters use Adam without weight decay. This matches PyTorch's official torch.optim.Muon which only supports 2D matrices.
3. Merge NS_EPS and ADAM_EPS into single EPS constant (both 1e-7).
4. Update dtype documentation to reflect actual behavior:
   - NS output (bfloat16) directly applied to parameters
   - Muon momentum buffer follows gradient dtype (not param dtype)
5. Update weight_decay docstring from ">=2D params" to "Muon-routed parameters" for accuracy with muon_2d_only=True.
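Taken together with `min_2d_dim`, the routing this describes amounts to roughly the following sketch (hypothetical helper; the real classification lives inside HybridMuonOptimizer's first `step()`):

```python
def route_param_sketch(p, min_2d_dim=1, muon_2d_only=True):
    # Illustrative classification mirroring the behavior described above.
    if p.ndim < 2:
        return "adam_1d"      # biases/norms -> Adam
    if muon_2d_only and p.ndim > 2:
        return "adam_nd"      # >2D params -> Adam (no weight decay) when muon_2d_only=True
    if min(p.shape) < min_2d_dim:
        return "adam_matrix"  # small matrices -> Adam fallback
    return "muon"             # remaining matrices -> momentum + Newton-Schulz
```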
Actionable comments posted: 3
🤖 Fix all issues with AI agents
In @deepmd/pt/optimizer/hybrid_muon.py:
- Around line 87-99: The _maybe_compile function currently calls torch.compile
directly which can raise exceptions; update it to call torch.compile(fn,
fullgraph=True, dynamic=True) inside a try/except that catches Exception (or
RuntimeError) and on failure logs or silently ignores the error and returns the
original fn so compilation failures fall back to eager execution; reference the
_maybe_compile function and torch.compile call and ensure the except path
returns fn (and optionally logs the exception via the module logger or
warnings).
- Around line 443-724: The foreach calls fail on heterogeneous devices/dtypes;
bucket tensors by (device,dtype) before calling torch._foreach_* and run the
foreach per-bucket (or fallback to per-parameter ops) for muon_params_for_decay
(torch._foreach_mul_), for muon_grads and muon_momentum_buffers
(torch._foreach_lerp_ and torch._foreach_lerp) and for
adam_matrix_params/raw_deltas when calling torch._foreach_norm and torch.stack;
specifically, group lists into buckets keyed by (tensor.device, tensor.dtype)
(like the existing buckets for Newton-Schulz), then call the corresponding
torch._foreach_* on each bucket's sublists and replace the original single-call
sites (references: muon_params_for_decay, muon_grads, muon_momentum_buffers,
adam_matrix_params, raw_deltas, and the torch._foreach_* invocations) so every
foreach sees homogeneous device/dtype inputs and you avoid mixed-precision
crashes.
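A minimal sketch of the per-bucket pattern this asks for (illustrative helper name; the real code would key its existing Newton-Schulz buckets the same way):

```python
from collections import defaultdict

import torch


def foreach_lerp_by_bucket(buffers, grads, weight):
    # Group paired tensors by (device, dtype) so every torch._foreach_* call
    # sees homogeneous inputs, then run the op once per bucket.
    buckets = defaultdict(lambda: ([], []))
    for buf, g in zip(buffers, grads, strict=True):
        key = (buf.device, buf.dtype)
        buckets[key][0].append(buf)
        buckets[key][1].append(g)
    for bufs, gs in buckets.values():
        torch._foreach_lerp_(bufs, gs, weight)
```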
🧹 Nitpick comments (5)
deepmd/pt/optimizer/hybrid_muon.py (1)
77-85: Align/justify `EPS` across AdaMuon vs HybridMuon (currently `1e-7` here).
You kept the standard NS coefficients (good). `EPS` differs from `deepmd/pt/optimizer/adamuon.py` (which historically used `1e-8` per repo learnings). If the difference is intentional (e.g., matching official Muon), please add a short comment explaining why HybridMuon diverges so users don’t “fix” it later. Based on learnings, the coefficients should stay unchanged.
Also applies to: 119-121
deepmd/pt/train/training.py (1)
159-178: Make `get_opt_param()` resilient to alias keys (if configs bypass normalization).
Argcheck introduces aliases like `muon_momentum`/`muon_min_2d_dim`, but `get_opt_param()` reads only `momentum`/`min_2d_dim`. If a caller provides raw configs without running the normalizer, HybridMuon/AdaMuon may silently ignore the alias values.
Proposed fix
```diff
-    "momentum": params.get("momentum", 0.95),
+    "momentum": params.get("momentum", params.get("muon_momentum", 0.95)),
@@
-    "min_2d_dim": params.get("min_2d_dim", 1),
+    "min_2d_dim": params.get("min_2d_dim", params.get("muon_min_2d_dim", 1)),
```
deepmd/utils/argcheck.py (1)
3420-3476: Docs: mention `lr_adjust_coeff` also affects small-2D Adam fallback (current implementation).
In `HybridMuonOptimizer`, `lr_adjust_coeff` is “dual-purpose” (match-RMS scaling and matrix-fallback LR scaling via `min(lr_adjust_coeff, 0.1)`). The schema doc currently describes only match-RMS scaling. Consider adding one sentence to prevent surprise when users tune it.
Also applies to: 3478-3566
source/tests/pt/test_hybrid_muon.py (2)
29-34: Consider moving the success return to an `else` block.
The linter suggests placing the success path in an `else` block for clearer control flow separation between success and error handling.
♻️ Suggested refactor
```diff
     try:
         a = torch.randn(4, 4, dtype=torch.bfloat16, device=device)
         _ = torch.mm(a, a.T)
-        return True
     except (RuntimeError, TypeError):
         return False
+    else:
+        return True
```
107-108: Add `strict=True` to `zip()` for defensive checking.
While both iterables originate from the same model, making length mismatch unlikely, adding `strict=True` provides early failure detection if the iteration logic ever changes.
♻️ Suggested fix
```diff
- for i, (p, init_p) in enumerate(zip(model.parameters(), initial_params)):
+ for i, (p, init_p) in enumerate(zip(model.parameters(), initial_params, strict=True)):
      self.assertFalse(torch.allclose(p, init_p), f"Parameter {i} did not change")
```
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- deepmd/pt/optimizer/__init__.py
- deepmd/pt/optimizer/hybrid_muon.py
- deepmd/pt/train/training.py
- deepmd/utils/argcheck.py
- source/tests/pt/test_hybrid_muon.py
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2026-01-10T04:28:18.703Z
Learnt from: OutisLi
Repo: deepmodeling/deepmd-kit PR: 5130
File: deepmd/pt/optimizer/adamuon.py:40-109
Timestamp: 2026-01-10T04:28:18.703Z
Learning: The coefficients (3.4445, -4.7750, 2.0315) in the `zeropower_via_newtonschulz5` function in `deepmd/pt/optimizer/adamuon.py` are standard across AdaMuon implementations and should not be modified for consistency with the reference implementations.
Applied to files:
- deepmd/pt/optimizer/hybrid_muon.py
- source/tests/pt/test_hybrid_muon.py
- deepmd/pt/train/training.py
📚 Learning: 2026-01-10T04:29:25.299Z
Learnt from: OutisLi
Repo: deepmodeling/deepmd-kit PR: 5130
File: deepmd/pt/optimizer/adamuon.py:44-113
Timestamp: 2026-01-10T04:29:25.299Z
Learning: The epsilon parameter (eps=1e-8) in the `zeropower_via_newtonschulz5` function in `deepmd/pt/optimizer/adamuon.py` follows the standard PyTorch AdaMuon reference implementations (which typically use eps=1e-7) and should not be modified for consistency with reference implementations, despite being smaller than bfloat16's machine epsilon.
Applied to files:
deepmd/pt/optimizer/hybrid_muon.py
📚 Learning: 2024-11-25T07:42:55.735Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4302
File: deepmd/pd/infer/inference.py:35-38
Timestamp: 2024-11-25T07:42:55.735Z
Learning: In the file `deepmd/pd/infer/inference.py`, when loading the model checkpoint in the `Tester` class, it's acceptable to not include additional error handling for loading the model state dictionary.
Applied to files:
deepmd/pt/train/training.py
🧬 Code graph analysis (4)
deepmd/pt/optimizer/__init__.py (3)
deepmd/pt/optimizer/hybrid_muon.py (1)
HybridMuonOptimizer (236-725)
deepmd/pt/optimizer/KFWrapper.py (1)
KFOptimizerWrapper (13-140)
deepmd/pt/optimizer/LKF.py (1)
LKFOptimizer(30-325)
deepmd/pt/optimizer/hybrid_muon.py (2)
deepmd/pt/train/training.py (1)
step (819-1191)
source/tests/pt/test_adamuon.py (1)
closure(330-336)
source/tests/pt/test_hybrid_muon.py (1)
deepmd/pt/optimizer/hybrid_muon.py (3)
HybridMuonOptimizer (236-725), zeropower_via_newtonschulz5 (169-196), step (409-725)
deepmd/pt/train/training.py (1)
deepmd/pt/optimizer/hybrid_muon.py (1)
HybridMuonOptimizer(236-725)
🪛 Ruff (0.14.10)
deepmd/pt/optimizer/hybrid_muon.py
196-196: Avoid specifying long messages outside the exception class
(TRY003)
226-226: Avoid specifying long messages outside the exception class
(TRY003)
324-324: Avoid specifying long messages outside the exception class
(TRY003)
source/tests/pt/test_hybrid_muon.py
32-32: Consider moving this statement to an else block
(TRY300)
107-107: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (20)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Test Python (7, 3.10)
- GitHub Check: Test Python (4, 3.13)
- GitHub Check: Test Python (3, 3.13)
- GitHub Check: Test Python (6, 3.10)
- GitHub Check: Test Python (3, 3.10)
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Test Python (2, 3.10)
- GitHub Check: Test C++ (false, false, false, true)
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Test C++ (true, false, false, true)
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Test C++ (false, true, true, false)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Test C++ (true, true, true, false)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C++ (rocm, rocm)
🔇 Additional comments (8)
deepmd/pt/optimizer/__init__.py (1)
5-20: LGTM: clean public API export for `HybridMuonOptimizer`.
deepmd/pt/train/training.py (1)
752-773: HybridMuon integration looks consistent with Adam/AdaMuon scheduling and stepping.
Instantiation wiring (betas, lr_adjust*, muon_2d_only, min_2d_dim) + inclusion in the Adam-family step path is coherent.
Also applies to: 843-861
deepmd/utils/argcheck.py (1)
3420-3431: Alias normalization is already handled correctly by the dargs framework.
The config is explicitly normalized via `deepmd.utils.argcheck.normalize()` in the entrypoint (deepmd/pt/entrypoints/main.py) before `Trainer` instantiation. This function uses dargs' `normalize_value()` method, which automatically handles alias-to-canonical key mapping for nested `Argument` and `Variant` schemas recursively. The optimizer code in training.py accesses parameters using only canonical names ("momentum", "min_2d_dim"), not the alias names. All entrypoint code paths follow the normalize → Trainer flow, so un-normalized config reaching the optimizer is not a concern.
source/tests/pt/test_hybrid_muon.py (5)
1-13: LGTM! License header and imports are appropriate for the test module.
40-82: LGTM! The Newton-Schulz orthogonalization tests provide good coverage:
- Orthogonality verification with appropriate tolerances for bf16 precision.
- Shape and dtype preservation for 2D and 3D inputs.
- Error handling for invalid 1D input.
110-191: LGTM! The optimizer tests provide solid coverage of key behaviors:
- Weight decay effect verification.
- Correct routing of 2D params to Muon (momentum_buffer) and 1D to Adam (exp_avg/exp_avg_sq).
- Adam fallback for small 2D matrices based on `min_2d_dim`.
- Different `lr_adjust` modes producing distinct updates.
Fixed seeds ensure reproducibility.
194-227: LGTM! State dict save/load test correctly verifies:
- State is properly serialized and deserialized.
- Both tensor values (via `torch.allclose`) and scalar values are compared.
- Per-parameter state integrity is maintained.
230-231: LGTM! Standard test entry point.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
deepmd/pt/train/training.py (1)
739-772: Add scheduler setup for AdaMuon optimizer.
The AdaMuon block (lines 739-751) lacks scheduler initialization, yet line 843 includes "AdaMuon" in the condition that calls `self.scheduler.get_last_lr()[0]` at line 844. This will cause an `AttributeError` at runtime. Add the missing scheduler initialization following the HybridMuon pattern (lines 768-772):
```python
if optimizer_state_dict is not None and self.restart_training:
    self.optimizer.load_state_dict(optimizer_state_dict)
self.scheduler = torch.optim.lr_scheduler.LambdaLR(
    self.optimizer,
    lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
)
```
Insert this block immediately after the AdaMuon optimizer initialization (after line 751).
🧹 Nitpick comments (2)
source/tests/pt/test_hybrid_muon.py (2)
15-34: LGTM! Robust BF16 support detection.
The function handles CUDA capability checks gracefully and falls back to a practical matmul test for CPU. This addresses the CI robustness concern from previous reviews.
Minor style note: the static analyzer suggests moving `return True` (line 32) to an `else` block for clarity, but this is optional.
♻️ Optional: Move return to else block per TRY300
```diff
     try:
         a = torch.randn(4, 4, dtype=torch.bfloat16, device=device)
         _ = torch.mm(a, a.T)
-        return True
     except (RuntimeError, TypeError):
         return False
+    else:
+        return True
```
84-193: Consider adding test coverage for `muon_2d_only=False`.
All tests use the default `muon_2d_only=True`. Adding a test with `muon_2d_only=False` would verify that ≥2D (e.g., 3D) parameters use Muon instead of Adam, exercising that routing path.
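A hedged sketch of such a test, reusing the state-key checks from the existing tests; the model and constructor arguments are illustrative and would need to match the real API:

```python
import torch

from deepmd.pt.optimizer import HybridMuonOptimizer


def test_muon_handles_3d_params_when_not_2d_only():
    conv = torch.nn.Conv1d(4, 4, kernel_size=3)  # weight is a 3D parameter
    opt = HybridMuonOptimizer(conv.parameters(), lr=1e-3, muon_2d_only=False)
    conv(torch.randn(2, 4, 8)).sum().backward()
    opt.step()
    state = opt.state[conv.weight]
    assert "momentum_buffer" in state   # Muon path, not Adam's exp_avg/exp_avg_sq
    assert "exp_avg" not in state
```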
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- deepmd/pt/train/training.py
- source/tests/pt/test_hybrid_muon.py
🧰 Additional context used
🧠 Learnings (3)
📓 Common learnings
Learnt from: OutisLi
Repo: deepmodeling/deepmd-kit PR: 5130
File: deepmd/pt/optimizer/adamuon.py:40-109
Timestamp: 2026-01-10T04:28:18.703Z
Learning: The coefficients (3.4445, -4.7750, 2.0315) in the `zeropower_via_newtonschulz5` function in `deepmd/pt/optimizer/adamuon.py` are standard across AdaMuon implementations and should not be modified for consistency with the reference implementations.
📚 Learning: 2026-01-10T04:28:18.703Z
Learnt from: OutisLi
Repo: deepmodeling/deepmd-kit PR: 5130
File: deepmd/pt/optimizer/adamuon.py:40-109
Timestamp: 2026-01-10T04:28:18.703Z
Learning: The coefficients (3.4445, -4.7750, 2.0315) in the `zeropower_via_newtonschulz5` function in `deepmd/pt/optimizer/adamuon.py` are standard across AdaMuon implementations and should not be modified for consistency with the reference implementations.
Applied to files:
- source/tests/pt/test_hybrid_muon.py
- deepmd/pt/train/training.py
📚 Learning: 2024-11-25T07:42:55.735Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4302
File: deepmd/pd/infer/inference.py:35-38
Timestamp: 2024-11-25T07:42:55.735Z
Learning: In the file `deepmd/pd/infer/inference.py`, when loading the model checkpoint in the `Tester` class, it's acceptable to not include additional error handling for loading the model state dictionary.
Applied to files:
deepmd/pt/train/training.py
🧬 Code graph analysis (1)
deepmd/pt/train/training.py (1)
deepmd/pt/optimizer/hybrid_muon.py (1)
HybridMuonOptimizer(236-725)
🪛 Ruff (0.14.11)
source/tests/pt/test_hybrid_muon.py
32-32: Consider moving this statement to an else block
(TRY300)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Build C++ (cuda120, cuda)
🔇 Additional comments (9)
deepmd/pt/train/training.py (5)
44-49: LGTM! The `HybridMuonOptimizer` import is correctly added alongside existing optimizer imports.
159-178: LGTM! The `get_opt_param` function correctly extracts parameters for the HybridMuon optimizer, including `muon_2d_only` and `min_2d_dim`, which are specific to HybridMuon as indicated in the external code snippet.
843-843: LGTM! Correctly includes `HybridMuon` in the optimizer types that use the LR scheduler.
656-658: LGTM! The log message consolidation to single-line format improves readability.
1584-1586: LGTM! The log message format change is consistent with the other logging updates in this PR.
source/tests/pt/test_hybrid_muon.py (4)
40-82: LGTM! The Newton-Schulz orthogonalization tests appropriately validate:
- Approximate orthogonality with tolerances suitable for bf16 precision
- Shape preservation across 2D and batched 3D inputs
- Error handling for invalid 1D input
Based on learnings, the underlying NS coefficients (3.4445, -4.7750, 2.0315) are standard and should not be modified.
84-143: LGTM! Good coverage of optimizer fundamentals:
- `test_step` validates parameters change after optimization
- `test_weight_decay` confirms decay reduces norms
- `test_muon_adam_separation` verifies the routing logic (Muon for 2D weights, Adam for 1D biases)
The `strict=True` in zip() at line 108 addresses the previous review feedback.
145-193: LGTM! These tests validate key HybridMuon behaviors:
- `test_muon_adam_fallback_small_2d`: Correctly tests the `min_2d_dim` threshold causing fallback to Adam
- `test_lr_adjust_modes`: Verifies that different `lr_adjust` values produce different update behaviors
196-229: LGTM! The state dict persistence test properly verifies that optimizer state survives save/load cycles by comparing per-parameter state tensors.
✏️ Tip: You can disable this entire section by setting review_details to false in your review settings.
Summary by CodeRabbit
✏️ Tip: You can customize this high-level summary in your review settings.