refactor: unify learning rate schedulers with array API #5154
base: master
Conversation
Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough

Refactors the learning-rate subsystem: adds BaseLR with warmup support and an abstract _decay_value/value contract, updates the exponential and cosine schedulers, introduces stop_lr_ratio/num_steps, converts the TF wrapper to a dict-based LearningRateSchedule, updates training/scheduler usage, and adds validations, docs, examples, and tests.

Changes
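To make the dict-based TF wrapper mentioned in the walkthrough concrete, here is a minimal usage sketch. The key names and methods (`type`, `start_lr`, `stop_lr`, `decay_steps`, `build`, `value`) follow the tests quoted later in this review; treat the snippet as illustrative, not as the canonical API.

```python
# Illustrative only: a dict-based schedule as described in the walkthrough.
# The key names follow the tests quoted later in this review; treat them as
# assumptions, not the canonical API.
from deepmd.tf.utils.learning_rate import LearningRateSchedule

lr_schedule = LearningRateSchedule(
    {
        "type": "exp",        # default scheduler type
        "start_lr": 1.0e-3,   # initial learning rate
        "stop_lr": 1.0e-5,    # final learning rate
        "decay_steps": 5000,  # interval between stepped decays
    }
)
# build() wires the schedule into the TF graph; num_steps is only known here.
# lr_tensor = lr_schedule.build(global_step, num_steps=100000)
# python_value = lr_schedule.value(0)  # plain float, useful for logging
```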
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Trainer
    participant TFWrapper as LearningRateSchedule
    participant BaseLR
    participant NumpyFunc as numpy_function
    Trainer->>TFWrapper: build(global_step, num_steps)
    TFWrapper->>BaseLR: instantiate BaseLR(params + num_steps)
    TFWrapper->>NumpyFunc: register callable -> BaseLR.value
    Trainer->>NumpyFunc: runtime global_step -> call BaseLR.value(step)
    NumpyFunc->>BaseLR: value(step)
    alt step < warmup_steps
        BaseLR->>BaseLR: compute warmup lr (linear interp)
    else
        BaseLR->>BaseLR: _decay_value(step - warmup_steps) -> exp/cosine
    end
    BaseLR-->>NumpyFunc: float lr
    NumpyFunc-->>Trainer: lr tensor
```
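The warmup-vs-decay branch in the diagram can also be pictured with a small, self-contained sketch. The helper below is hypothetical — it only mirrors the linear-warmup-then-cosine-decay logic the diagram describes, not the actual `BaseLR` code.

```python
import numpy as np


def schedule_value(step, start_lr, stop_lr, num_steps, warmup_steps=0,
                   warmup_start_lr=0.0):
    """Hypothetical mirror of the diagram: linear warmup, then cosine decay."""
    step = np.asarray(step, dtype=np.float64)
    # Linear interpolation from warmup_start_lr to start_lr during warmup.
    warmup_lr = warmup_start_lr + (start_lr - warmup_start_lr) * step / max(warmup_steps, 1)
    # Cosine decay over the remaining steps (one of the two decay flavors).
    decay_step = np.maximum(step - warmup_steps, 0.0)
    decay_total = num_steps - warmup_steps
    decay_lr = stop_lr + 0.5 * (start_lr - stop_lr) * (1.0 + np.cos(np.pi * decay_step / decay_total))
    return np.where(step < warmup_steps, warmup_lr, decay_lr)


# Example: warm up for 1000 steps, then cosine-decay to stop_lr by step 10000.
print(schedule_value([0, 500, 1000, 10000], 1e-3, 1e-5, 10000, warmup_steps=1000))
```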
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs
Suggested labels
Suggested reviewers
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing touches

📜 Recent review details

Configuration used: Repository UI
Review profile: CHILL
Plan: Pro

📒 Files selected for processing (3)
🚧 Files skipped from review as they are similar to previous changes (1)
🧰 Additional context used

🧬 Code graph analysis (2)

source/tests/pt/model/test_model.py (1)
source/tests/pd/model/test_model.py (1)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (40)
🔇 Additional comments (11)
✏️ Tip: You can disable this entire section by setting review_details to false in your review settings.

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@deepmd/dpmodel/utils/learning_rate.py`:
- Around line 336-338: BaseLR.__init__ currently computes self.lr_min_factor =
self.stop_lr / self.start_lr which will raise ZeroDivisionError if start_lr is
zero; add an explicit validation at the start of BaseLR.__init__ that checks
start_lr is non-zero (and preferably > 0) and raise a clear ValueError like
"start_lr must be > 0" if it fails, so the invalid configuration is caught with
a descriptive message before any division occurs; update callers/tests if they
relied on a zero value.
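A minimal sketch of the guard this prompt asks for; the class fragment and error message are assumptions, not the project's actual code:

```python
class BaseLR:
    def __init__(self, start_lr: float, stop_lr: float, **kwargs) -> None:
        # Hypothetical guard: fail fast with a clear message instead of a
        # ZeroDivisionError further down when stop_lr / start_lr is computed.
        if start_lr <= 0.0:
            raise ValueError(f"start_lr must be > 0, got {start_lr}")
        self.start_lr = start_lr
        self.stop_lr = stop_lr
        self.lr_min_factor = self.stop_lr / self.start_lr
```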
🧹 Nitpick comments (5)
deepmd/tf/fit/polar.py (1)

859-867: LGTM! Docstring updated to reflect the renamed scheduler class. The documentation change from `LearningRateExp` to `LearningRateSchedule` correctly aligns with the PR's refactoring of the learning rate scheduler API. Consider adding a type annotation to the function signature for consistency:

```diff
-    def get_loss(self, loss: dict, lr) -> Loss:
+    def get_loss(self, loss: dict, lr: "LearningRateSchedule") -> Loss:
```

This would also apply to `PolarFittingSeA.get_loss` at line 618, which currently lacks parameter documentation entirely.

source/tests/consistent/test_learning_rate.py (1)

74-78: Redundant skip check inside `compare_test_with_warmup_ref`. The `skipTest` guard on lines 75-76 is redundant since all call sites already check `if self.warmup_step is not None` before invoking this method. Consider removing the internal check or, if keeping it as defensive coding, note that `warmup_ref` being `None` while `warmup_step` is set would indicate a bug in `setUp`.

♻️ Suggested simplification

```diff
 def compare_test_with_warmup_ref(self, step: Array) -> None:
-    if self.warmup_ref is None:
-        self.skipTest("warmup not enabled")
+    assert self.warmup_ref is not None, "warmup_ref should be set when warmup_step is set"
     test = self.lr.value(step)
     np.testing.assert_allclose(self.warmup_ref, to_numpy_array(test), atol=1e-10)
```

deepmd/pt/utils/utils.py (1)

229-237: Type hint mismatch: `int` is handled but not declared in the signature. The function signature at line 230 declares `torch.Tensor | np.ndarray | float | None`, but line 234 also handles `int` inputs. Additionally, the `@overload` signatures (lines 221-227) don't cover the new input types (`float`, `int`, `np.ndarray`), which may cause type checker warnings for callers.

Proposed fix to align type hints

```diff
 @overload
 def to_numpy_array(xx: torch.Tensor) -> np.ndarray: ...


 @overload
 def to_numpy_array(xx: None) -> None: ...


+@overload
+def to_numpy_array(xx: np.ndarray) -> np.ndarray: ...
+
+
+@overload
+def to_numpy_array(xx: float | int) -> np.ndarray: ...
+
+
 def to_numpy_array(
-    xx: torch.Tensor | np.ndarray | float | None,
+    xx: torch.Tensor | np.ndarray | float | int | None,
 ) -> np.ndarray | None:
```

deepmd/pd/utils/utils.py (1)

250-253: Note: Pre-existing dead code. Line 251 provides a default value (`np.float64`) when the key is not found, so `prec` can never be `None` at line 252. The check at lines 252-253 is unreachable. This is pre-existing and unrelated to the current changes, but worth noting for future cleanup.

deepmd/dpmodel/utils/learning_rate.py (1)

72-75: Potential type inconsistency when both stop_lr and stop_ratio are None. If validation in `argcheck.py` is bypassed (e.g., programmatic instantiation without going through argument normalization), `self.stop_lr` could remain `None`, which would cause issues downstream (e.g., line 246 `max(self.stop_lr, 1e-10)` or line 338 `self.stop_lr / self.start_lr`). Consider adding a defensive runtime check here, or ensure the documentation clearly states that one of `stop_lr` or `stop_ratio` must be provided.

♻️ Suggested defensive check

```diff
 # === Step 1. Compute stop_lr from stop_ratio if needed ===
 # Mutual exclusion validated in argcheck.py
 if stop_ratio is not None:
     self.stop_lr = start_lr * stop_ratio
 else:
     self.stop_lr = stop_lr  # type: ignore[assignment]
+
+if self.stop_lr is None:
+    raise ValueError("Either stop_lr or stop_ratio must be provided")
```
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (21)
- deepmd/dpmodel/utils/learning_rate.py
- deepmd/pd/train/training.py
- deepmd/pd/utils/utils.py
- deepmd/pt/train/training.py
- deepmd/pt/utils/utils.py
- deepmd/tf/fit/dipole.py
- deepmd/tf/fit/dos.py
- deepmd/tf/fit/ener.py
- deepmd/tf/fit/fitting.py
- deepmd/tf/fit/polar.py
- deepmd/tf/train/trainer.py
- deepmd/tf/utils/__init__.py
- deepmd/tf/utils/learning_rate.py
- deepmd/utils/argcheck.py
- source/tests/consistent/test_learning_rate.py
- source/tests/pd/model/test_model.py
- source/tests/pd/test_lr.py
- source/tests/pt/model/test_model.py
- source/tests/pt/test_lr.py
- source/tests/tf/test_lr.py
- source/tests/universal/dpmodel/utils/test_learning_rate.py
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2024-10-15T22:22:24.889Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4219
File: deepmd/utils/learning_rate.py:48-53
Timestamp: 2024-10-15T22:22:24.889Z
Learning: Methods in `deepmd/utils/learning_rate.py` that return NumPy scalar types should have return type annotations using the corresponding NumPy types, such as `np.float64`.
Applied to files:
- deepmd/pd/utils/utils.py
- deepmd/tf/fit/fitting.py
📚 Learning: 2025-12-12T13:40:14.334Z
Learnt from: CR
Repo: deepmodeling/deepmd-kit PR: 0
File: AGENTS.md:0-0
Timestamp: 2025-12-12T13:40:14.334Z
Learning: Run core tests with `pytest source/tests/tf/test_dp_test.py::TestDPTestEner::test_1frame -v` to validate basic functionality
Applied to files:
source/tests/tf/test_lr.py
📚 Learning: 2024-10-08T15:32:11.479Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4144
File: source/api_cc/tests/test_deeppot_dpa_pt.cc:166-246
Timestamp: 2024-10-08T15:32:11.479Z
Learning: Refactoring between test classes `TestInferDeepPotDpaPt` and `TestInferDeepPotDpaPtNopbc` is addressed in PR `#3905`.
Applied to files:
source/tests/pt/test_lr.py
🧬 Code graph analysis (11)
deepmd/pt/utils/utils.py (1)
  deepmd/pt/model/network/network.py (1)
    Tensor (34-35)
source/tests/tf/test_lr.py (2)
  deepmd/dpmodel/utils/learning_rate.py (2)
    LearningRateExp (170-278), value (123-166)
  deepmd/tf/utils/learning_rate.py (5)
    LearningRateSchedule (20-123), value (102-123), base_lr (50-66), build (68-100), start_lr (38-47)
deepmd/tf/utils/learning_rate.py (1)
  deepmd/dpmodel/utils/learning_rate.py (2)
    BaseLR (25-166), value (123-166)
source/tests/pd/model/test_model.py (1)
  deepmd/tf/utils/learning_rate.py (2)
    LearningRateSchedule (20-123), start_lr (38-47)
deepmd/pt/train/training.py (2)
  deepmd/pd/train/training.py (1)
    step (722-963)
  deepmd/tf/utils/learning_rate.py (2)
    value (102-123), start_lr (38-47)
source/tests/universal/dpmodel/utils/test_learning_rate.py (4)
  deepmd/pd/utils/utils.py (3)
    to_numpy_array (230-230), to_numpy_array (234-234), to_numpy_array (237-256)
  deepmd/pt/utils/utils.py (3)
    to_numpy_array (222-222), to_numpy_array (226-226), to_numpy_array (229-249)
  deepmd/dpmodel/utils/learning_rate.py (3)
    LearningRateCosine (282-374), LearningRateExp (170-278), value (123-166)
  deepmd/tf/utils/learning_rate.py (2)
    start_lr (38-47), value (102-123)
deepmd/tf/train/trainer.py (1)
  deepmd/tf/utils/learning_rate.py (1)
    LearningRateSchedule (20-123)
deepmd/dpmodel/utils/learning_rate.py (2)
  deepmd/tf/utils/learning_rate.py (1)
    start_lr (38-47)
  deepmd/utils/argcheck.py (1)
    register (185-215)
deepmd/tf/utils/__init__.py (1)
  deepmd/tf/utils/learning_rate.py (1)
    LearningRateSchedule (20-123)
source/tests/pt/model/test_model.py (1)
  deepmd/tf/utils/learning_rate.py (2)
    LearningRateSchedule (20-123), start_lr (38-47)
source/tests/pd/test_lr.py (2)
  deepmd/tf/utils/learning_rate.py (4)
    LearningRateSchedule (20-123), start_lr (38-47), base_lr (50-66), build (68-100)
  deepmd/dpmodel/utils/learning_rate.py (1)
    LearningRateExp (170-278)
🪛 Ruff (0.14.11)
deepmd/tf/utils/learning_rate.py
34-34: Avoid specifying long messages outside the exception class
(TRY003)
65-65: Avoid specifying long messages outside the exception class
(TRY003)
122-122: Avoid specifying long messages outside the exception class
(TRY003)
deepmd/dpmodel/utils/learning_rate.py
40-40: Unused method argument: kwargs
(ARG002)
86-86: Avoid specifying long messages outside the exception class
(TRY003)
88-88: Avoid specifying long messages outside the exception class
(TRY003)
90-90: Avoid specifying long messages outside the exception class
(TRY003)
241-243: Avoid specifying long messages outside the exception class
(TRY003)
deepmd/utils/argcheck.py
2506-2509: Avoid specifying long messages outside the exception class
(TRY003)
2511-2514: Avoid specifying long messages outside the exception class
(TRY003)
2542-2545: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (45)
source/tests/consistent/test_learning_rate.py (2)

56-68: LGTM! The test setup correctly initializes reference values for both regular steps and warmup steps, enabling cross-backend consistency validation.

83-104: LGTM! The test methods consistently validate learning rate values across PyTorch, array_api_strict, and JAX backends, with proper conditional warmup testing.

deepmd/tf/fit/ener.py (1)

852-866: LGTM! The docstring type annotation update to `LearningRateSchedule` aligns with the broader refactoring in this PR.

deepmd/tf/fit/fitting.py (1)

68-83: LGTM! The docstring type annotation update in the abstract `get_loss` method correctly reflects the new `LearningRateSchedule` type used throughout the codebase.

deepmd/tf/fit/dipole.py (1)

384-398: LGTM! The docstring type annotation update to `LearningRateSchedule` is consistent with the base class and other fitting implementations.

deepmd/tf/fit/dos.py (1)

651-668: LGTM! The docstring type annotation update to `LearningRateSchedule` is consistent with the base class and other fitting implementations.

source/tests/pt/model/test_model.py (2)

51-53: LGTM! The import change from `LearningRateExp` to `LearningRateSchedule` correctly aligns with the refactored TensorFlow learning rate API. The test appropriately uses separate LR implementations for each backend being compared.

228-236: LGTM! The migration to dict-based `LearningRateSchedule` construction is correct and provides all required parameters. This aligns with the unified API that accepts a configuration dictionary.

deepmd/pd/utils/utils.py (1)

242-246: LGTM! The added handling for scalar inputs and numpy arrays mirrors the changes in the PT backend, maintaining consistency across the codebase.

deepmd/tf/utils/__init__.py (2)

20-27: LGTM! The `__all__` list correctly reflects the updated export, maintaining consistency between the module's imports and its public interface.

9-11: This concern is not applicable. The search shows no code imports `LearningRateExp` from `deepmd.tf.utils`, and the test file explicitly imports `LearningRateExp` from `deepmd.dpmodel.utils.learning_rate`. The `LearningRateSchedule` exported from `deepmd.tf.utils` is the TensorFlow-specific wrapper class, not a renamed version of a previously exported symbol. This is the correct architecture: core learning rate algorithms (LearningRateExp) reside in dpmodel, while framework-specific wrappers (LearningRateSchedule) reside in tf/pt/pd submodules. Likely an incorrect or invalid review comment.

source/tests/pd/model/test_model.py (2)

51-53: LGTM! The import correctly uses `LearningRateSchedule` for the TensorFlow side of the consistency test, aligning with the refactored API.

228-236: LGTM! The dict-based `LearningRateSchedule` construction is correct and consistent with the PT test file, providing all required configuration parameters for the exponential decay schedule.

source/tests/tf/test_lr.py (3)

1-21: LGTM! Well-structured test file with proper imports and clear documentation. The test module correctly separates validation tests from build/integration tests, and the docstring clearly states that core algorithm tests are in the dpmodel tests.

23-44: LGTM! Validation tests properly verify error handling. The tests correctly validate that `ValueError` is raised for a missing `start_lr` and `RuntimeError` is raised when accessing `value()` or `base_lr` before `build()` is called.

47-110: LGTM! Build and integration tests are comprehensive. The tests properly validate:

- Tensor output type and dtype
- Default scheduler type inference
- Value consistency between the TF tensor and `BaseLR.value()`
- Accessor methods work correctly before and after build

deepmd/pt/train/training.py (3)

275-278: LGTM! Correct stop_steps assignment. Setting `stop_steps` to `self.num_steps` aligns with the unified BaseLR interface where warmup is now handled internally.

698-702: LGTM! Lambda LR calculation is correct. The lambda correctly computes the learning rate multiplier by dividing the absolute LR value by `start_lr`. The `step + self.start_step` offset properly handles training resumption.

797-802: LGTM! Simplified pref_lr assignment. With warmup logic now handled internally by `BaseLR`, the training loop correctly uses the current learning rate directly without conditional warmup branching.

source/tests/pt/test_lr.py (4)

22-24: LGTM! Corrected decay_steps range. The constraint `decay_steps ∈ [400, 500]` with `stop_steps ∈ [500, 1500]` ensures `decay_steps` never exceeds `stop_steps`, which would otherwise raise a `ValueError` in `LearningRateExp`.

35-42: LGTM! Correct migration to dictionary-based construction. The `LearningRateSchedule` constructor now uses a dictionary with explicit type specification, aligning with the unified API.

77-83: LGTM! Good refactor to use a local variable. Using `decay_step_for_rate` instead of mutating `self.decay_step` prevents side effects that could affect subsequent test iterations.

121-139: LGTM! Cosine scheduler test correctly migrated. The test validates basic cosine decay behavior including start, end, and midpoint values with the new keyword-argument API.

deepmd/tf/train/trainer.py (2)

106-118: LGTM! Clean refactor to use LearningRateSchedule. The helper function correctly:

- Extracts and handles `scale_by_worker` separately
- Passes filtered params to `LearningRateSchedule`
- Returns a properly typed tuple

429-434: LGTM! Simplified logging message. The log message now focuses on the essential learning rate values (start, current, final) without exposing internal decay parameters, which aligns with the abstracted scheduler interface.

deepmd/pd/train/training.py (3)

241-244: LGTM! Consistent with the PyTorch implementation. The `get_lr` function follows the same pattern as the PyTorch trainer, setting `stop_steps` to total steps and using `BaseLR` directly.

582-586: LGTM! Paddle LambdaDecay correctly configured. The lambda function matches the PyTorch pattern, computing the LR multiplier as `lr_exp.value(step + start_step) / lr_exp.start_lr` for proper LR scheduling with resume support.

747-749: LGTM! Simplified pref_lr assignment consistent with PyTorch. The warmup-free `pref_lr = cur_lr` assignment aligns with the unified approach where warmup is handled internally by `BaseLR`.

deepmd/tf/utils/learning_rate.py (2)

30-36: LGTM on constructor design. The constructor correctly stores the configuration dictionary and validates the required `start_lr` parameter. The lazy initialization of `_base_lr` is appropriate since `stop_steps` is only known at build time.

68-100: Consider potential performance implications of `tf.numpy_function`. Using `tf.numpy_function` wraps Python/NumPy execution, which works correctly but has implications:

- It breaks TF graph optimization and XLA compilation
- The function runs on CPU even in GPU-enabled sessions

This is acceptable for learning rate computation (a low-frequency, scalar operation), but worth noting in documentation if performance-critical use cases arise.

source/tests/pd/test_lr.py (3)

21-23: Good adjustment to test parameter ranges. The comment correctly notes that `decay_steps` must not exceed `stop_steps`. The range `np.arange(400, 501, 100)` produces `[400, 500]`, which correctly stays within the minimum `stop_steps` of 500.

34-41: Correct dict-based construction for LearningRateSchedule. The test properly uses the new dictionary-based API for `LearningRateSchedule`, matching the expected interface where `stop_steps` is provided during `build()`.

76-83: Good refactor to use a local variable. Using `decay_step_for_rate` instead of modifying `self.decay_step` avoids unintended state mutation between test iterations. This is a clean improvement.

source/tests/universal/dpmodel/utils/test_learning_rate.py (4)

18-27: LGTM on the basic decay test. The test correctly verifies:

- Initial LR at step 0 equals `start_lr` (1e-3)
- Final LR at step 10000 equals `stop_lr` (1e-5)

The tolerances (`rtol=1e-10` and `rtol=1e-5`) are appropriate for the precision requirements.

80-93: Verify warmup boundary behavior at exactly `warmup_steps`. The test at line 92 checks that `lr.value(1000)` equals `1e-3`, but step 1000 is exactly at `warmup_steps=1000`. Based on the implementation in `BaseLR.value()` (line 162 in learning_rate.py), the condition is `step < self.warmup_steps`, so step 1000 would enter the decay phase, not warmup. At step 1000 (decay_step=0), the decay phase returns `start_lr` since `decay_rate^0 = 1`. So the assertion is correct, but the test comment could clarify this boundary behavior.

147-163: Good coverage for array/vectorized inputs. Testing array inputs ensures JIT compatibility across backends. The assertions correctly verify:

- Shape preservation (`lrs.shape == (5,)`)
- Warmup at step 0 returns `warmup_start_lr` (0.0)
- End of warmup at step 1000 returns `start_lr` (1e-3)

207-217: Good validation test for the decay_steps constraint. The test correctly verifies that `decay_steps > stop_steps` raises `ValueError` with appropriate message content checks.

deepmd/utils/argcheck.py (4)

2483-2515: Correct mutual exclusion validation for stop_lr/stop_ratio. The validation correctly:

- Checks if both are provided (error)
- Checks if neither is provided (error)
- Returns `True` on success

The error messages are clear and informative, which outweighs the TRY003 style concern.

2536-2546: Correct warmup validation logic. The check correctly identifies non-zero `warmup_steps` (since the default is 0) and non-`None` `warmup_ratio` as the conditions for mutual exclusion. This aligns with the behavior where `warmup_steps=0` means "no warmup" rather than "explicit zero warmup steps."

2549-2619: Well-designed common argument builder. The `_learning_rate_common_args` function effectively reduces code duplication by:

- Building common args (start_lr, stop_lr, stop_ratio)
- Accepting `extra_args` for scheduler-specific parameters
- Appending warmup-related args at the end

This design allows `learning_rate_exp` to inject `decay_steps` and `decay_rate` while sharing the rest.

2690-2714: Good integration of validation checks. The nested `_check_lr_args` function composes both validation checks and is correctly passed to `extra_check`. This ensures validation runs during argument normalization.

deepmd/dpmodel/utils/learning_rate.py (4)

31-99: Solid base class design with warmup support. `BaseLR.__init__` correctly:

- Computes `stop_lr` from `stop_ratio` if provided
- Computes `warmup_steps` from `warmup_ratio` if provided
- Validates step ranges at runtime
- Computes derived values (`warmup_start_lr`, `decay_stop_steps`)

Note: The `kwargs` parameter (flagged by static analysis) is necessary for the plugin system's `**kwargs` forwarding pattern.

137-166: Correct warmup and decay phase implementation. The `value()` method correctly:

- Handles scalar vs array inputs
- Uses the array API for backend-agnostic operations
- Computes linear warmup interpolation
- Delegates decay computation to the subclass `_decay_value()`
- Uses `xp.where` for JIT-compatible branching

The use of `xp.maximum(..., 0.0)` for `decay_step` ensures negative values don't propagate during the warmup phase.

240-250: Good validation and numerical stability handling. The implementation correctly:

- Validates `decay_steps <= decay_total` to prevent invalid configurations
- Clamps `stop_lr` to `1e-10` for the log computation to avoid `log(0)`
- Stores the original `stop_lr` as `min_lr` for clamping

This ensures numerical stability while preserving the intended final learning rate behavior.

364-373: Correct cosine decay with boundary handling. The cosine annealing formula is correctly implemented:

- Uses `decay_stop_steps` instead of `stop_steps` to account for warmup
- Clamps to `min_lr` for steps beyond the decay phase using `xp.where`

This ensures the learning rate doesn't oscillate beyond the intended schedule.
✏️ Tip: You can disable this entire section by setting review_details to false in your review settings.
Pull request overview
This pull request refactors the learning rate schedulers to use array_api_compat for backend-agnostic implementation, consolidating logic from TensorFlow, PyTorch, and PaddlePaddle backends into a unified dpmodel layer. The refactoring adds warmup support and flexible configuration options while ensuring JIT compatibility across backends.
Changes:
- Unified learning rate scheduler implementation in dpmodel using array API operations
- Added warmup functionality (warmup_steps, warmup_ratio, warmup_start_factor) with mutual exclusion validation
- Added stop_ratio parameter as alternative to stop_lr with mutual exclusion validation
- Updated all backend implementations to use the unified BaseLR, removing duplicated warmup logic
- Added comprehensive tests for core functionality and cross-backend consistency
Reviewed changes
Copilot reviewed 21 out of 21 changed files in this pull request and generated 6 comments.
Show a summary per file
| File | Description |
|---|---|
| deepmd/dpmodel/utils/learning_rate.py | Core unified learning rate implementation with warmup and array API support |
| deepmd/utils/argcheck.py | Updated argument definitions with warmup parameters and mutual exclusion validation |
| deepmd/tf/utils/learning_rate.py | Simplified TensorFlow wrapper to use unified BaseLR |
| deepmd/tf/train/trainer.py | Updated to use new LearningRateSchedule API |
| deepmd/tf/fit/*.py | Updated type annotations for learning rate parameter |
| deepmd/pt/train/training.py | Removed local warmup implementation, delegated to BaseLR |
| deepmd/pt/utils/utils.py | Extended to_numpy_array to handle scalars and numpy arrays |
| deepmd/pd/train/training.py | Removed local warmup implementation, delegated to BaseLR |
| deepmd/pd/utils/utils.py | Extended to_numpy_array to handle scalars and numpy arrays |
| source/tests/universal/dpmodel/utils/test_learning_rate.py | New comprehensive tests for core learning rate functionality |
| source/tests/tf/test_lr.py | New tests for TensorFlow wrapper |
| source/tests/pt/test_lr.py | Updated tests to use keyword arguments |
| source/tests/pd/test_lr.py | Updated tests to use keyword arguments |
| source/tests/consistent/test_learning_rate.py | Enhanced consistency tests with warmup validation |
TODO:
- Refactor BaseLR in dpmodel to use array_api_compat for backend-agnostic implementation
- Consolidate learning rate logic from TF/PT/PD backends into unified dpmodel layer
- Use array API operations (xp.where, xp.clip, etc.) for JIT compatibility across backends
- Add warmup support (warmup_steps, warmup_ratio, warmup_start_factor) during refactoring
- Add stop_ratio parameter as alternative to stop_lr for flexible configuration
- Implement mutual exclusion validation for stop_lr/stop_ratio and warmup_steps/warmup_ratio
- Update all backends to use unified BaseLR implementation
- Add comprehensive consistency tests across NumPy/PyTorch/JAX/array_api_strict backends
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@deepmd/dpmodel/utils/learning_rate.py`:
- Around line 194-277: The constructor for the exponential scheduler (__init__
in LearningRateExp) currently allows decay_steps <= 0 which leads to
division-by-zero; add an early validation after reading decay_steps (and before
using it in decay rate computation and in _decay_value) to raise ValueError if
decay_steps <= 0 with a clear message like "decay_steps must be > 0"; ensure
this check is performed before comparing to decay_total and before any math that
divides by decay_steps so both the np.log(...) / (decay_total /
self.decay_steps) computation and the floor-division in _decay_value are
protected.
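A sketch of where such a guard could sit, under the assumption that the decay rate is derived from `start_lr`, `stop_lr`, `decay_steps`, and the total step count as described above; the names are illustrative, not the project's actual constructor:

```python
import numpy as np


class LearningRateExp:
    """Hypothetical fragment; only the validation pattern matters here."""

    def __init__(self, start_lr, stop_lr, decay_steps, num_steps, **kwargs):
        # Fail fast before decay_steps is used as a divisor below.
        if decay_steps <= 0:
            raise ValueError(f"decay_steps must be > 0, got {decay_steps}")
        if decay_steps > num_steps:
            raise ValueError("decay_steps must not exceed num_steps")
        self.decay_steps = decay_steps
        # decay_rate chosen so the schedule reaches stop_lr after num_steps.
        self.decay_rate = np.exp(np.log(stop_lr / start_lr) / (num_steps / decay_steps))
```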
In `@deepmd/tf/train/trainer.py`:
- Around line 106-118: When building the learning-rate schedule in
DPTrainer.build (and the same logic around lines 242-246), guard the case
stop_batch == 0 and avoid calling LearningRateSchedule.build which passes
num_steps=0 into BaseLR; instead short-circuit and return a constant LR tensor
or a trivial LearningRateSchedule that does not call BaseLR.build. Concretely:
detect stop_batch == 0 before calling
LearningRateSchedule.build(self.global_step, self.stop_batch), and in that
branch create a fixed scalar/constant schedule (or skip build and set lr to a
constant tensor) so BaseLR validation (num_steps > 0) is not triggered.
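A possible shape for that short-circuit, sketched with an assumed helper function rather than the real trainer code:

```python
import tensorflow as tf


# Sketch only: guard the degenerate stop_batch == 0 case before building the
# schedule, so BaseLR's num_steps > 0 validation is never triggered.
def build_lr(lr_schedule, global_step, stop_batch, start_lr):
    if stop_batch == 0:
        # No training steps: a constant tensor stands in for the schedule.
        return tf.constant(start_lr, dtype=tf.float32)
    return lr_schedule.build(global_step, stop_batch)
```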
🧹 Nitpick comments (5)
deepmd/pt/utils/utils.py (1)

221-238: Overload signatures are incomplete for the new input types. The function signature now accepts `torch.Tensor | np.ndarray | float | None`, but the `@overload` declarations only cover `torch.Tensor` and `None`. This may cause type checkers to flag valid calls with `np.ndarray` or `float` arguments as errors.

♻️ Proposed fix to add missing overloads

```diff
 @overload
 def to_numpy_array(xx: torch.Tensor) -> np.ndarray: ...


 @overload
 def to_numpy_array(xx: None) -> None: ...


+@overload
+def to_numpy_array(xx: np.ndarray) -> np.ndarray: ...
+
+
+@overload
+def to_numpy_array(xx: float) -> np.ndarray: ...
+
+
 def to_numpy_array(
     xx: torch.Tensor | np.ndarray | float | None,
 ) -> np.ndarray | None:
```

deepmd/pd/train/training.py (1)

241-245: Avoid mutating `lr_params` in-place when injecting `num_steps`. This currently mutates the caller's config dict (and, for multi-task, potentially multiple shared dicts). Prefer copying before modification.

Proposed fix

```diff
 def get_lr(lr_params: dict[str, Any]) -> BaseLR:
-    lr_params["num_steps"] = self.num_steps
-    lr_schedule = BaseLR(**lr_params)
+    _params = dict(lr_params)
+    _params["num_steps"] = self.num_steps
+    lr_schedule = BaseLR(**_params)
     return lr_schedule
```

source/tests/pd/model/test_model.py (1)

306-306: Consider using keyword arguments for consistency. This instantiation uses positional arguments while the codebase (including `_get_dp_lr` at lines 229-236 and the PT test file) uses keyword arguments. Consider using keyword arguments for clarity and consistency with the new API style.

♻️ Suggested change

```diff
-    my_lr = MyLRExp(self.start_lr, self.stop_lr, self.decay_steps, self.num_steps)
+    my_lr = MyLRExp(
+        start_lr=self.start_lr,
+        stop_lr=self.stop_lr,
+        decay_steps=self.decay_steps,
+        num_steps=self.num_steps,
+    )
```

source/tests/pt/model/test_model.py (1)

306-306: Consider using keyword arguments for consistency. This instantiation uses positional arguments while the codebase (including `_get_dp_lr` at lines 229-236 and `source/tests/pt/test_lr.py`) uses keyword arguments. Consider using keyword arguments for clarity and consistency with the new API style.

♻️ Suggested change

```diff
-    my_lr = MyLRExp(self.start_lr, self.stop_lr, self.decay_steps, self.num_steps)
+    my_lr = MyLRExp(
+        start_lr=self.start_lr,
+        stop_lr=self.stop_lr,
+        decay_steps=self.decay_steps,
+        num_steps=self.num_steps,
+    )
```

source/tests/universal/dpmodel/utils/test_learning_rate.py (1)

204-240: Add tests for mutual exclusion parameter validation. The implementation validates that `stop_lr` and `stop_lr_rate` are mutually exclusive, and that `warmup_steps` and `warmup_ratio` are mutually exclusive. Consider adding test cases to verify these validations raise appropriate errors:

- Both `stop_lr` and `stop_lr_rate` provided
- Both `warmup_steps` and `warmup_ratio` provided
- Neither parameter provided in each pair (if applicable)
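A sketch of what those missing cases could look like as unit tests; the keyword names (`stop_lr_rate`, `warmup_ratio`, ...) and the raised exception type mirror this review's wording, but the exact constructor signature is an assumption:

```python
import unittest

from deepmd.dpmodel.utils.learning_rate import LearningRateExp


class TestMutualExclusion(unittest.TestCase):
    def test_stop_lr_and_ratio_both_given(self):
        # Only one of stop_lr / stop_lr_rate may be supplied (assumed names).
        with self.assertRaises(ValueError):
            LearningRateExp(
                start_lr=1e-3, stop_lr=1e-5, stop_lr_rate=1e-2,
                decay_steps=500, num_steps=10000,
            )

    def test_warmup_steps_and_ratio_both_given(self):
        with self.assertRaises(ValueError):
            LearningRateExp(
                start_lr=1e-3, stop_lr=1e-5, decay_steps=500,
                num_steps=10000, warmup_steps=100, warmup_ratio=0.1,
            )


if __name__ == "__main__":
    unittest.main()
```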
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (21)
- deepmd/dpmodel/utils/learning_rate.py
- deepmd/pd/train/training.py
- deepmd/pd/utils/utils.py
- deepmd/pt/train/training.py
- deepmd/pt/utils/utils.py
- deepmd/tf/fit/dipole.py
- deepmd/tf/fit/dos.py
- deepmd/tf/fit/ener.py
- deepmd/tf/fit/fitting.py
- deepmd/tf/fit/polar.py
- deepmd/tf/train/trainer.py
- deepmd/tf/utils/__init__.py
- deepmd/tf/utils/learning_rate.py
- deepmd/utils/argcheck.py
- source/tests/consistent/test_learning_rate.py
- source/tests/pd/model/test_model.py
- source/tests/pd/test_lr.py
- source/tests/pt/model/test_model.py
- source/tests/pt/test_lr.py
- source/tests/tf/test_lr.py
- source/tests/universal/dpmodel/utils/test_learning_rate.py
🚧 Files skipped from review as they are similar to previous changes (8)
- deepmd/tf/fit/fitting.py
- source/tests/consistent/test_learning_rate.py
- deepmd/tf/fit/dipole.py
- source/tests/tf/test_lr.py
- deepmd/tf/fit/ener.py
- deepmd/tf/fit/polar.py
- deepmd/pd/utils/utils.py
- source/tests/pd/test_lr.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2024-10-08T15:32:11.479Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4144
File: source/api_cc/tests/test_deeppot_dpa_pt.cc:166-246
Timestamp: 2024-10-08T15:32:11.479Z
Learning: Refactoring between test classes `TestInferDeepPotDpaPt` and `TestInferDeepPotDpaPtNopbc` is addressed in PR `#3905`.
Applied to files:
source/tests/pt/test_lr.py
🧬 Code graph analysis (11)
deepmd/pt/utils/utils.py (1)
  deepmd/pt/model/network/network.py (1)
    Tensor (34-35)
deepmd/tf/train/trainer.py (1)
  deepmd/tf/utils/learning_rate.py (1)
    LearningRateSchedule (20-123)
deepmd/tf/utils/__init__.py (1)
  deepmd/tf/utils/learning_rate.py (1)
    LearningRateSchedule (20-123)
deepmd/tf/utils/learning_rate.py (1)
  deepmd/dpmodel/utils/learning_rate.py (2)
    BaseLR (25-189), value (146-189)
source/tests/pd/model/test_model.py (1)
  deepmd/tf/utils/learning_rate.py (3)
    LearningRateSchedule (20-123), build (68-100), start_lr (38-47)
deepmd/dpmodel/utils/learning_rate.py (3)
  deepmd/tf/utils/learning_rate.py (2)
    start_lr (38-47), value (102-123)
  deepmd/pd/train/training.py (1)
    step (722-963)
  deepmd/pt/train/training.py (1)
    step (773-1142)
source/tests/pt/test_lr.py (2)
  deepmd/tf/utils/learning_rate.py (5)
    LearningRateSchedule (20-123), start_lr (38-47), base_lr (50-66), build (68-100), value (102-123)
  deepmd/dpmodel/utils/learning_rate.py (2)
    LearningRateExp (193-301), value (146-189)
deepmd/pd/train/training.py (3)
  deepmd/pt/train/training.py (1)
    step (773-1142)
  deepmd/dpmodel/utils/learning_rate.py (1)
    value (146-189)
  deepmd/tf/utils/learning_rate.py (2)
    value (102-123), start_lr (38-47)
source/tests/universal/dpmodel/utils/test_learning_rate.py (1)
  deepmd/dpmodel/utils/learning_rate.py (3)
    LearningRateCosine (305-397), LearningRateExp (193-301), value (146-189)
deepmd/utils/argcheck.py (1)
  deepmd/utils/plugin.py (2)
    register (41-59), register (122-141)
source/tests/pt/model/test_model.py (1)
  deepmd/tf/utils/learning_rate.py (3)
    LearningRateSchedule (20-123), build (68-100), start_lr (38-47)
🪛 Ruff (0.14.11)
deepmd/tf/utils/learning_rate.py
34-34: Avoid specifying long messages outside the exception class
(TRY003)
65-65: Avoid specifying long messages outside the exception class
(TRY003)
122-122: Avoid specifying long messages outside the exception class
(TRY003)
deepmd/dpmodel/utils/learning_rate.py
40-40: Unused method argument: kwargs
(ARG002)
75-78: Avoid specifying long messages outside the exception class
(TRY003)
80-83: Avoid specifying long messages outside the exception class
(TRY003)
96-99: Avoid specifying long messages outside the exception class
(TRY003)
109-109: Avoid specifying long messages outside the exception class
(TRY003)
111-111: Avoid specifying long messages outside the exception class
(TRY003)
113-113: Avoid specifying long messages outside the exception class
(TRY003)
264-266: Avoid specifying long messages outside the exception class
(TRY003)
deepmd/utils/argcheck.py
2506-2509: Avoid specifying long messages outside the exception class
(TRY003)
2511-2514: Avoid specifying long messages outside the exception class
(TRY003)
2542-2545: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (31)
- GitHub Check: Test Python (4, 3.10)
- GitHub Check: Test Python (9, 3.13)
- GitHub Check: Test Python (10, 3.10)
- GitHub Check: Test Python (6, 3.13)
- GitHub Check: Test Python (12, 3.13)
- GitHub Check: Test Python (9, 3.10)
- GitHub Check: Test Python (11, 3.10)
- GitHub Check: Test Python (10, 3.13)
- GitHub Check: Test Python (8, 3.10)
- GitHub Check: Test Python (11, 3.13)
- GitHub Check: Test Python (2, 3.10)
- GitHub Check: Test Python (12, 3.10)
- GitHub Check: Test Python (5, 3.13)
- GitHub Check: Test Python (8, 3.13)
- GitHub Check: Test Python (7, 3.10)
- GitHub Check: Test Python (7, 3.13)
- GitHub Check: Test Python (1, 3.10)
- GitHub Check: Test Python (4, 3.13)
- GitHub Check: Test Python (5, 3.10)
- GitHub Check: Test Python (3, 3.10)
- GitHub Check: Test Python (3, 3.13)
- GitHub Check: Test Python (6, 3.10)
- GitHub Check: Test Python (2, 3.13)
- GitHub Check: Test Python (1, 3.13)
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Test C++ (false, true, true, false)
- GitHub Check: Test C++ (false, false, false, true)
- GitHub Check: Test C++ (true, true, true, false)
🔇 Additional comments (36)
deepmd/tf/utils/__init__.py (1)

9-11: LGTM! The public export correctly updates from `LearningRateExp` to `LearningRateSchedule`, aligning with the unified LR scheduler refactoring. This is a breaking change as noted in the PR labels.

Also applies to: 20-27

deepmd/pt/utils/utils.py (1)

234-238: LGTM! The new input handling logic correctly short-circuits for scalars and numpy arrays. The assert at line 238 is reachable and serves as a guard to ensure only `torch.Tensor` inputs proceed to the conversion logic.

deepmd/pt/train/training.py (3)

275-278: LGTM! The `get_lr` function correctly injects `num_steps` into the learning rate parameters before constructing the `BaseLR` instance, aligning with the new unified LR scheduler API.

698-702: LGTM! The LambdaLR lambda correctly computes the multiplicative factor by dividing the absolute learning rate value by `start_lr`. This properly delegates warmup and decay logic to the `BaseLR.value()` method while maintaining compatibility with PyTorch's LambdaLR scheduler.

Also applies to: 722-726

797-799: LGTM! The `pref_lr` simplification is correct since warmup handling is now encapsulated within the `BaseLR.value()` method. The scheduler's `get_last_lr()` already returns the warmup-adjusted learning rate.

deepmd/tf/utils/learning_rate.py (3)

30-36: LGTM! The constructor correctly validates the presence of `start_lr` and stores the configuration for deferred schedule building.

68-100: LGTM! The `build()` method correctly:

- Creates a parameter copy to avoid mutation
- Defaults to the `"exp"` type for backward compatibility
- Uses `tf.numpy_function` (a TF 2.x API) for runtime LR evaluation
- Preserves tensor shape information and casts to `float32` for model compatibility

102-123: LGTM! The `value()` method correctly guards against unbuilt state and properly delegates to the `BaseLR.value()` method with appropriate type conversion.

deepmd/tf/fit/dos.py (1)

651-668: LGTM! The docstring correctly updates the `lr` parameter type from `LearningRateExp` to `LearningRateSchedule`, aligning with the renamed class. The method implementation using `lr.start_lr()` is compatible with the new API.

deepmd/tf/train/trainer.py (1)

429-434: Verify the LR step indexing convention (`value(stop_batch)` vs `value(stop_batch - 1)`). The log prints `self.lr.value(stop_batch)` as the "final lr". Depending on whether "num_steps" is treated as a count or a last-step index, the LR actually used on the last optimizer update is often `value(stop_batch - 1)`. Please confirm this matches the new BaseLR convention and cross-backend tests.

deepmd/utils/argcheck.py (1)

2483-2547: Nice consolidation + constraints; ensure examples/docs are updated for the required stop settings. Requiring exactly one of `stop_lr`/`stop_lr_rate` and the warmup mutual exclusion makes sense, but it's a user-facing breaking change (configs that omitted both will now hard-fail). Please ensure the example `input.json` files and docs are updated accordingly (matches the PR TODOs).

Optional: address Ruff TRY003 if CI enforces it

```diff
-    raise ValueError(
-        "stop_lr and stop_lr_rate are mutually exclusive. "
-        f"Got stop_lr={data['stop_lr']}, stop_lr_rate={data['stop_lr_rate']}"
-    )
+    raise ValueError(
+        f"stop_lr and stop_lr_rate are mutually exclusive: stop_lr={data['stop_lr']}, "
+        f"stop_lr_rate={data['stop_lr_rate']}"
+    )  # noqa: TRY003
```

Also applies to: 2549-2672, 2689-2715

deepmd/dpmodel/utils/learning_rate.py (1)

146-190: Request verification: "num_steps" and "final LR" off-by-one semantics across backends. With `decay_num_steps = num_steps - warmup_steps` and cosine using `step / decay_num_steps`, `stop_lr` is reached at `step == num_steps` (after warmup). Many training loops run steps `[0, num_steps-1]`, so they may never hit the exact stop LR. Please confirm this is intentional and consistent across the TF/PT/PD integrations and the new consistency tests. Also (optional), consider keeping `step` integral through exp's `step // decay_steps` to avoid accidental non-integer step inputs being silently accepted.

Optional refactor: integer exponent in exp decay

```diff
 def _decay_value(self, step: int | Array) -> Array:
 @@
-        step_lr = self.start_lr * xp.pow(
+        step_i = xp.astype(step, xp.int64)
+        exp_i = xp.astype(step_i // self.decay_steps, xp.float64)
+        step_lr = self.start_lr * xp.pow(
             xp.asarray(self.decay_rate, device=array_api_compat.device(step)),
-            xp.astype(step // self.decay_steps, xp.float64),
+            exp_i,
         )
```

Also applies to: 278-301, 363-397

deepmd/pd/train/training.py (1)

582-586: Verify: Adam scheduler path behavior for multi-task + `learning_rate_dict`, and the LR end-step convention. `LambdaDecay` references `self.lr_exp.start_lr`; if `self.lr_exp` is a dict (multi-task + `learning_rate_dict`), this will break unless that combo is intentionally unsupported—in which case, please fail fast with a clear error earlier. Also please confirm the intended "final step" convention for `value(step + start_step)` vs BaseLR's `num_steps` definition (to ensure the stop LR is reached when expected).

source/tests/pd/model/test_model.py (5)

25-25: LGTM! Import correctly updated to use `LearningRateExp` from `deepmd.dpmodel.utils.learning_rate` aliased as `MyLRExp` for the Paddle test.

51-53: LGTM! Import updated to use the new `LearningRateSchedule` wrapper for TensorFlow.

111-111: LGTM! Variable renamed from `stop_steps` to `num_steps` to align with the unified API.

140-140: LGTM! Correctly passes `self.num_steps` to the `dp_lr.build()` method.

228-236: LGTM! The `_get_dp_lr` method correctly returns a `LearningRateSchedule` with a dict payload containing the required configuration parameters.

source/tests/pt/model/test_model.py (5)

34-34: LGTM! Import correctly updated to use `LearningRateExp` from `deepmd.pt.utils.learning_rate` aliased as `MyLRExp` for the PyTorch test.

51-53: LGTM! Import updated to use the new `LearningRateSchedule` wrapper for TensorFlow.

111-111: LGTM! Variable renamed from `stop_steps` to `num_steps` to align with the unified API.

140-140: LGTM! Correctly passes `self.num_steps` to the `dp_lr.build()` method.

228-236: LGTM! The `_get_dp_lr` method correctly returns a `LearningRateSchedule` with a dict payload containing the required configuration parameters.

source/tests/pt/test_lr.py (7)

13-15: LGTM! Import updated to use the new `LearningRateSchedule` from TensorFlow utils.

22-24: LGTM! Good addition of the comment explaining the constraint that `decay_steps` must not exceed `num_steps`. The variable naming change from `stop_steps` to `num_steps` aligns with the unified API.

35-42: LGTM! Correctly migrated to a dict-based payload for `LearningRateSchedule` construction.

48-53: LGTM! Updated to use keyword arguments for `LearningRateExp` instantiation, improving clarity.

77-84: Good use of a local variable to avoid side effects. Using `decay_step_for_rate` as a local variable instead of modifying `self.decay_step` prevents unintended side effects in subsequent test iterations.

85-91: LGTM! Properly uses keyword arguments and sets `stop_lr=1e-10` as a minimal value for the decay rate override test.

121-138: LGTM! The `TestLearningRateCosine` class properly tests the cosine annealing schedule with updated variable naming (`num_steps`) and keyword arguments.

source/tests/universal/dpmodel/utils/test_learning_rate.py (6)

1-12: LGTM! Good file setup with appropriate imports for testing the unified learning rate schedules.

15-51: LGTM! Comprehensive tests for basic exponential decay, including the `stop_lr_rate` parameter and explicit `decay_rate` override functionality.

53-75: LGTM! Good coverage of cosine annealing basics, including the midpoint value check at `(start_lr + stop_lr) / 2`.

77-145: LGTM! Thorough warmup testing covering:

- `warmup_steps` with both exponential and cosine decay
- The `warmup_ratio` parameter
- The `warmup_start_factor` parameter
- The zero-warmup edge case

147-179: LGTM! Good tests for array input support, which is essential for JIT compatibility across backends.

181-202: LGTM! Important tests verifying that learning rates clamp to `stop_lr` when steps exceed `num_steps`.
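For readers unfamiliar with the LambdaLR pattern referenced in these comments, the sketch below shows the general wiring — an absolute-valued schedule divided by `start_lr` to obtain the multiplicative factor. The `ConstantHalving` stand-in is purely illustrative and not part of the project.

```python
import torch


# Illustrative stand-in for a BaseLR-style schedule (value() returns an
# absolute LR; LambdaLR expects a multiplicative factor relative to start_lr).
class ConstantHalving:
    start_lr = 1e-3

    def value(self, step):
        return self.start_lr * 0.5 ** (step // 1000)


schedule = ConstantHalving()
start_step = 0  # offset used when resuming training

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=schedule.start_lr)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lambda step: schedule.value(step + start_step) / schedule.start_lr,
)
```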
✏️ Tip: You can disable this entire section by setting review_details to false in your review settings.
- Add comprehensive documentation for warmup parameters (warmup_steps, warmup_ratio, warmup_start_factor)
- Document stop_lr_rate as an alternative to stop_lr for ratio-based specification
- Add complete documentation for cosine annealing scheduler
- Update theory section with separate warmup and decay phase formulas
- Provide 7 configuration examples covering both exp and cosine types with warmup variants
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@doc/train/training-advanced.md`:
- Around line 86-93: The exponential decay formula lr(t) = start_lr * decay_rate
^ ( (t - warmup_steps) / decay_steps ) is missing that the division is integer
(floor) division, causing confusion about discrete updates; update the
documentation near lr(t) and the discussion of start_lr/stop_lr to either
replace “/ decay_steps” with an explicit floor notation (e.g., ⌊(t -
warmup_steps) / decay_steps⌋) or add a short note stating the division is floor
(integer) division so readers know the learning rate updates in discrete steps
rather than continuously, and reference the symbols lr(t), start_lr, decay_rate,
warmup_steps, and decay_steps when making the change.
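Written out, the decay-phase formula with the floor made explicit would read roughly as follows (a reconstruction of the prompt above, not the documentation's final wording):

```latex
\mathrm{lr}(t) = \mathrm{start\_lr}\cdot
\mathrm{decay\_rate}^{\,\left\lfloor \frac{t-\mathrm{warmup\_steps}}{\mathrm{decay\_steps}}\right\rfloor},
\qquad t \ge \mathrm{warmup\_steps},
```

where the floor makes explicit that the learning rate only changes every `decay_steps` steps rather than continuously.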
🧹 Nitpick comments (3)
doc/train/training-advanced.md (3)
66-70: Replace hard tabs with spaces in the JSON examples. The JSON configuration examples contain hard tabs, which can cause rendering inconsistencies. Markdown files should use spaces for indentation.

♻️ Formatting recommendation: Replace all hard tabs with spaces (typically 4 spaces per indentation level) in all JSON code blocks throughout the examples section.

Also applies to: 120-125, 131-136, 144-151, 159-165, 175-179, 187-191, 197-201, 209-215

90-93: Add language identifiers to code blocks. The code blocks on lines 90-93 and 100-103 lack language identifiers (e.g., use a `text` fence for the plain-text formula `lr(t) = start_lr * decay_rate ^ ( (t - warmup_steps) / decay_steps )`). Adding identifiers improves readability and enables proper syntax highlighting. Apply the same change to the cosine formula block.

Also applies to: 100-103

116-116: Use proper heading levels instead of bold text for examples. Example titles (lines 116, 127, 140, 155, 183, 193, 205) use bold text instead of proper Markdown headings. Using heading levels (e.g., `####`) improves document structure and navigation.

♻️ Suggested improvement

```diff
-**Example 1: Basic exponential decay without warmup**
+#### Example 1: Basic exponential decay without warmup
```

Apply the same change to all seven example titles.

Also applies to: 127-127, 140-140, 155-155, 183-183, 193-193, 205-205
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
doc/train/training-advanced.md
🧰 Additional context used
🪛 LanguageTool
doc/train/training-advanced.md
[grammar] ~58-~58: Ensure spelling is correct
Context: ..."cosine"`). Both types support optional warmup and can use either absolute stopping le...
(QB_NEW_EN_ORTHOGRAPHY_ERROR_IDS_1)
🪛 markdownlint-cli2 (0.18.1)
doc/train/training-advanced.md
66-66: Hard tabs
Column: 1
(MD010, no-hard-tabs)
66-66: Hard tabs
Column: 9
(MD010, no-hard-tabs)
67-67: Hard tabs
Column: 1
(MD010, no-hard-tabs)
67-67: Hard tabs
Column: 13
(MD010, no-hard-tabs)
68-68: Hard tabs
Column: 1
(MD010, no-hard-tabs)
68-68: Hard tabs
Column: 12
(MD010, no-hard-tabs)
69-69: Hard tabs
Column: 1
(MD010, no-hard-tabs)
69-69: Hard tabs
Column: 16
(MD010, no-hard-tabs)
70-70: Hard tabs
Column: 1
(MD010, no-hard-tabs)
70-70: Hard tabs
Column: 13
(MD010, no-hard-tabs)
90-90: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
100-100: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
116-116: Emphasis used instead of a heading
(MD036, no-emphasis-as-heading)
120-120: Hard tabs
Column: 1
(MD010, no-hard-tabs)
120-120: Hard tabs
Column: 9
(MD010, no-hard-tabs)
121-121: Hard tabs
Column: 1
(MD010, no-hard-tabs)
121-121: Hard tabs
Column: 13
(MD010, no-hard-tabs)
122-122: Hard tabs
Column: 1
(MD010, no-hard-tabs)
122-122: Hard tabs
Column: 12
(MD010, no-hard-tabs)
123-123: Hard tabs
Column: 1
(MD010, no-hard-tabs)
123-123: Hard tabs
Column: 16
(MD010, no-hard-tabs)
127-127: Emphasis used instead of a heading
(MD036, no-emphasis-as-heading)
131-131: Hard tabs
Column: 1
(MD010, no-hard-tabs)
131-131: Hard tabs
Column: 9
(MD010, no-hard-tabs)
132-132: Hard tabs
Column: 1
(MD010, no-hard-tabs)
132-132: Hard tabs
Column: 13
(MD010, no-hard-tabs)
133-133: Hard tabs
Column: 1
(MD010, no-hard-tabs)
133-133: Hard tabs
Column: 17
(MD010, no-hard-tabs)
134-134: Hard tabs
Column: 1
(MD010, no-hard-tabs)
134-134: Hard tabs
Column: 16
(MD010, no-hard-tabs)
140-140: Emphasis used instead of a heading
(MD036, no-emphasis-as-heading)
144-144: Hard tabs
Column: 1
(MD010, no-hard-tabs)
144-144: Hard tabs
Column: 9
(MD010, no-hard-tabs)
145-145: Hard tabs
Column: 1
(MD010, no-hard-tabs)
145-145: Hard tabs
Column: 13
(MD010, no-hard-tabs)
146-146: Hard tabs
Column: 1
(MD010, no-hard-tabs)
146-146: Hard tabs
Column: 12
(MD010, no-hard-tabs)
147-147: Hard tabs
Column: 1
(MD010, no-hard-tabs)
147-147: Hard tabs
Column: 16
(MD010, no-hard-tabs)
148-148: Hard tabs
Column: 1
(MD010, no-hard-tabs)
148-148: Hard tabs
Column: 17
(MD010, no-hard-tabs)
149-149: Hard tabs
Column: 1
(MD010, no-hard-tabs)
155-155: Emphasis used instead of a heading
(MD036, no-emphasis-as-heading)
159-159: Hard tabs
Column: 1
(MD010, no-hard-tabs)
159-159: Hard tabs
Column: 9
(MD010, no-hard-tabs)
160-160: Hard tabs
Column: 1
(MD010, no-hard-tabs)
160-160: Hard tabs
Column: 13
(MD010, no-hard-tabs)
161-161: Hard tabs
Column: 1
(MD010, no-hard-tabs)
161-161: Hard tabs
Column: 17
(MD010, no-hard-tabs)
162-162: Hard tabs
Column: 1
(MD010, no-hard-tabs)
162-162: Hard tabs
Column: 16
(MD010, no-hard-tabs)
163-163: Hard tabs
Column: 1
(MD010, no-hard-tabs)
163-163: Hard tabs
Column: 17
(MD010, no-hard-tabs)
175-175: Hard tabs
Column: 1
(MD010, no-hard-tabs)
175-175: Hard tabs
Column: 9
(MD010, no-hard-tabs)
176-176: Hard tabs
Column: 1
(MD010, no-hard-tabs)
176-176: Hard tabs
Column: 13
(MD010, no-hard-tabs)
177-177: Hard tabs
Column: 1
(MD010, no-hard-tabs)
177-177: Hard tabs
Column: 12
(MD010, no-hard-tabs)
183-183: Emphasis used instead of a heading
(MD036, no-emphasis-as-heading)
187-187: Hard tabs
Column: 1
(MD010, no-hard-tabs)
187-187: Hard tabs
Column: 9
(MD010, no-hard-tabs)
188-188: Hard tabs
Column: 1
(MD010, no-hard-tabs)
188-188: Hard tabs
Column: 13
(MD010, no-hard-tabs)
189-189: Hard tabs
Column: 1
(MD010, no-hard-tabs)
189-189: Hard tabs
Column: 12
(MD010, no-hard-tabs)
193-193: Emphasis used instead of a heading
(MD036, no-emphasis-as-heading)
197-197: Hard tabs
Column: 1
(MD010, no-hard-tabs)
197-197: Hard tabs
Column: 9
(MD010, no-hard-tabs)
198-198: Hard tabs
Column: 1
(MD010, no-hard-tabs)
198-198: Hard tabs
Column: 13
(MD010, no-hard-tabs)
199-199: Hard tabs
Column: 1
(MD010, no-hard-tabs)
199-199: Hard tabs
Column: 17
(MD010, no-hard-tabs)
205-205: Emphasis used instead of a heading
(MD036, no-emphasis-as-heading)
209-209: Hard tabs
Column: 1
(MD010, no-hard-tabs)
209-209: Hard tabs
Column: 9
(MD010, no-hard-tabs)
210-210: Hard tabs
Column: 1
(MD010, no-hard-tabs)
210-210: Hard tabs
Column: 13
(MD010, no-hard-tabs)
211-211: Hard tabs
Column: 1
(MD010, no-hard-tabs)
211-211: Hard tabs
Column: 12
(MD010, no-hard-tabs)
212-212: Hard tabs
Column: 1
(MD010, no-hard-tabs)
212-212: Hard tabs
Column: 17
(MD010, no-hard-tabs)
213-213: Hard tabs
Column: 1
(MD010, no-hard-tabs)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (40)
- GitHub Check: Test C++ (true, false, false, true)
- GitHub Check: Test C++ (true, true, true, false)
- GitHub Check: Test C++ (false, false, false, true)
- GitHub Check: Test C++ (false, true, true, false)
- GitHub Check: Test Python (9, 3.10)
- GitHub Check: Test Python (10, 3.13)
- GitHub Check: Test Python (12, 3.13)
- GitHub Check: Test Python (8, 3.10)
- GitHub Check: Test Python (12, 3.10)
- GitHub Check: Test Python (11, 3.10)
- GitHub Check: Test Python (7, 3.13)
- GitHub Check: Test Python (7, 3.10)
- GitHub Check: Test Python (11, 3.13)
- GitHub Check: Test Python (9, 3.13)
- GitHub Check: Test Python (10, 3.10)
- GitHub Check: Test Python (8, 3.13)
- GitHub Check: Test Python (4, 3.13)
- GitHub Check: Test Python (4, 3.10)
- GitHub Check: Test Python (2, 3.13)
- GitHub Check: Test Python (2, 3.10)
- GitHub Check: Test Python (6, 3.10)
- GitHub Check: Test Python (5, 3.10)
- GitHub Check: Test Python (1, 3.10)
- GitHub Check: Test Python (5, 3.13)
- GitHub Check: Test Python (6, 3.13)
- GitHub Check: Test Python (3, 3.10)
- GitHub Check: Test Python (1, 3.13)
- GitHub Check: Test Python (3, 3.13)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Analyze (python)
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
🔇 Additional comments (1)
doc/train/training-advanced.md (1)
78-81: All documented parameter constraints are properly enforced in the implementation. The verification confirms that:

- `stop_lr` and `stop_lr_rate` are mutually exclusive, with validation at lines 74-78 of `deepmd/dpmodel/utils/learning_rate.py`
- Either `stop_lr` or `stop_lr_rate` must be provided, enforced at lines 79-83
- `warmup_steps` and `warmup_ratio` are mutually exclusive, validated at lines 95-99
- `warmup_start_factor` defaults to 0.0 as documented, with the default defined at line 39

All constraints mentioned in the documentation match the actual implementation behavior. No changes needed.
✏️ Tip: You can disable this entire section by setting review_details to false in your review settings.
…ay mode

- Rename parameter `stop_lr_rate` to `stop_lr_ratio` for clearer semantics
- Add `smooth` parameter to LearningRateExp for continuous vs stepped decay
- Add comprehensive mathematical formulas to learning rate docstrings
- Fix dtype handling in pt/pd to_numpy_array (respect GLOBAL_NP_FLOAT_PRECISION)
- Remove redundant validation and caching in TF LearningRateSchedule wrapper
- Update documentation and examples with new parameter names and smooth mode

Breaking change: `stop_lr_rate` is now `stop_lr_ratio` in learning rate config.
Actionable comments posted: 7
🤖 Fix all issues with AI agents
In `@deepmd/pt/utils/utils.py`:
- Around line 237-246: The precision lookup uses reverse_precision_dict (built
from PT_PRECISION_DICT) against xx.dtype which fails for numpy arrays; move the
ndarray branch before the reverse lookup or add a separate numpy-dtype lookup
using NP_PRECISION_DICT. Concretely, check isinstance(xx, np.ndarray) first and
return xx.astype(NP_PRECISION_DICT.get(xx.dtype, desired_default)) or, if you
want to preserve original numpy precision, simply return xx as-is; otherwise
proceed with reverse_precision_dict -> prec -> NP_PRECISION_DICT lookup and the
torch.Tensor assertion.
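A simplified sketch of the ndarray-first ordering the prompt describes; the real function also routes through the project's precision dictionaries, which are omitted here:

```python
import numpy as np
import torch


def to_numpy_array(xx):
    """Simplified sketch: handle scalars and numpy arrays before any
    torch-specific dtype lookup, so numpy dtypes never hit the reverse map."""
    if xx is None:
        return None
    if isinstance(xx, (float, int)):
        return np.asarray(xx, dtype=np.float64)
    if isinstance(xx, np.ndarray):
        # Return as-is (or cast via a numpy-dtype table) instead of looking
        # the dtype up in a torch-keyed precision dict.
        return xx
    assert isinstance(xx, torch.Tensor)
    return xx.detach().cpu().numpy()
```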
In `@deepmd/tf/utils/learning_rate.py`:
- Around line 45-63: Ruff flags TRY003 for the RuntimeError message in the
base_lr property; either silence it by appending "# noqa: TRY003" to the raise
line or replace the RuntimeError with a small custom exception (e.g., class
LearningRateNotBuiltError(RuntimeError): pass) and raise
LearningRateNotBuiltError("Learning rate schedule is not built yet."). Update
the Raises section in the docstring accordingly and apply the same change to the
other occurrence covering lines 98-120 (same pattern referencing self._base_lr
and the raise).
- Around line 64-97: The code in build() hardcodes float64 for the numpy dtype
and tf.numpy_function Tout, which conflicts with the project's
GLOBAL_TF_FLOAT_PRECISION; update the _lr_value numpy conversion and the
tf.numpy_function Tout to use GLOBAL_TF_FLOAT_PRECISION (map to the matching
numpy dtype for np.asarray and to the TensorFlow dtype for Tout) so
BaseLR.value(step) is converted using that precision; ensure
GLOBAL_TF_FLOAT_PRECISION is imported/available in this module and that the
returned tf.Tensor uses that dtype consistently (you may need to cast the
tf.numpy_function result to GLOBAL_TF_FLOAT_PRECISION if required).
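A sketch of wiring the global precision through `tf.numpy_function`, assuming `GLOBAL_TF_FLOAT_PRECISION` is a `tf.DType`; the stand-in constant below replaces the import from `deepmd.tf.env`.

```python
import numpy as np
import tensorflow as tf

GLOBAL_TF_FLOAT_PRECISION = tf.float64  # stand-in for the env constant


def build_lr_tensor(base_lr, global_step: tf.Tensor) -> tf.Tensor:
    def _lr_value(step: np.ndarray) -> np.ndarray:
        # Convert with the same precision the rest of the TF graph uses.
        return np.asarray(
            base_lr.value(step),
            dtype=GLOBAL_TF_FLOAT_PRECISION.as_numpy_dtype,
        )

    lr = tf.numpy_function(_lr_value, [global_step], Tout=GLOBAL_TF_FLOAT_PRECISION)
    lr.set_shape([])
    return lr
```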
In `@deepmd/utils/argcheck.py`:
- Around line 2549-2581: The validation in _check_decay_steps_args incorrectly
rejects typical decay rates (e.g., 0.95); change the decay_rate check to allow
values in the (0, 1] range by replacing the condition with something like: if
decay_rate is not None and (decay_rate <= 0 or decay_rate > 1): raise
ValueError(f"decay_rate ({decay_rate}) must be > 0 and <= 1."), keeping the
lr_type and decay_steps logic unchanged.
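The relaxed check, spelled out as a tiny standalone sketch (the helper name is hypothetical; in the PR it lives inside `_check_decay_steps_args`).

```python
def check_decay_rate(decay_rate: float | None) -> None:
    # Accept anything in (0, 1]; values like 0.95 are typical decay rates.
    if decay_rate is not None and (decay_rate <= 0 or decay_rate > 1):
        raise ValueError(f"decay_rate ({decay_rate}) must be > 0 and <= 1.")
```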
In `@doc/train/training-advanced.md`:
- Around line 64-238: Replace tab characters with spaces in all JSON code blocks
(e.g., the "learning_rate" examples where keys like "learning_rate", "type",
"start_lr", "decay_steps" are indented with tabs), add a fenced-code language
label to every triple-backtick block (use ```json for JSON examples or ```text
for plain text), and convert any bold-as-heading lines (e.g., "**Basic
parameters**", "**Additional parameters for `exp` type only:**", "**Learning
rate formula for `exp` type:**", etc.) into proper Markdown headings (e.g.,
prepend appropriate `#/`##/###) so lint rules MD010, MD040 and MD036 are
satisfied.
- Around line 105-114: The doc reuses the name decay_steps for the cosine
denominator which conflicts with the exponential schedule's decay_steps; update
the cosine section to compute and use a distinct name (e.g., decay_phase_length
or decay_length) defined as decay_phase_length = numb_steps - warmup_steps and
use it in the cosine formula lr(t) = stop_lr + (start_lr - stop_lr) / 2 * (1 +
cos(pi * (t - warmup_steps) / decay_phase_length)); ensure references to
decay_steps in the cosine paragraph are renamed to this new symbol to avoid
confusion with the exp schedule.
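A plain-Python sketch of the renamed quantity, assuming the documented warmup-then-cosine schedule; names follow the comment above rather than the implementation.

```python
import math


def cosine_lr(t, start_lr, stop_lr, warmup_steps, numb_steps):
    decay_phase_length = numb_steps - warmup_steps
    if t < warmup_steps:
        # Linear warmup with a zero start factor.
        return start_lr * t / warmup_steps
    t_rel = min(t - warmup_steps, decay_phase_length)
    return stop_lr + (start_lr - stop_lr) / 2 * (
        1 + math.cos(math.pi * t_rel / decay_phase_length)
    )
```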
♻️ Duplicate comments (1)
deepmd/dpmodel/utils/learning_rate.py (1)
31-123: Add explicit `start_lr > 0` validation to avoid ZeroDivision/log-domain failures.

Even if "start_lr cannot be 0" by convention, current code will crash (e.g., `self.stop_lr / self.start_lr` in cosine, and `log(clamped_stop_lr / self.start_lr)` in exp) with a non-obvious stack trace if config is wrong.

Also applies to: 441-442
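A minimal sketch of the explicit guard, assuming it sits at the top of `BaseLR.__init__` as the duplicate comment suggests.

```python
def validate_start_lr(start_lr: float) -> None:
    # Fail fast with a readable message instead of a ZeroDivisionError
    # or a math-domain error deep inside the scheduler.
    if start_lr <= 0:
        raise ValueError(f"start_lr must be > 0, got {start_lr}.")
```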
🧹 Nitpick comments (5)
deepmd/pt/utils/utils.py (1)
222-227: Consider adding overloads for new input types.

The function now accepts `np.ndarray` and `float` inputs, but the overloads only cover `torch.Tensor` and `None`. Type checkers may flag calls with these new types as invalid.

♻️ Suggested overloads

```diff
 @overload
 def to_numpy_array(xx: torch.Tensor) -> np.ndarray: ...


 @overload
 def to_numpy_array(xx: None) -> None: ...


+@overload
+def to_numpy_array(xx: np.ndarray) -> np.ndarray: ...
+
+
+@overload
+def to_numpy_array(xx: float) -> np.ndarray: ...
+
+
 def to_numpy_array(
     xx: torch.Tensor | np.ndarray | float | None,
 ) -> np.ndarray | None:
```

deepmd/utils/argcheck.py (1)

2483-2547: Stop-LR and warmup mutual-exclusion checks are clear and match the new config semantics.

deepmd/dpmodel/utils/learning_rate.py (1)
333-361: `xp.clip(..., None)` may not be Array-API-strict; prefer one-sided clamp via `xp.maximum`.

If you truly need `array_api_strict` compatibility, `clip`'s signature can be a gotcha across namespaces. A one-sided lower clamp avoids relying on `None` support.

Proposed change

```diff
@@
-        # Clip to min_lr for numerical stability in JIT
-        step_lr = xp.clip(step_lr, self.min_lr, None)
+        # Clamp to min_lr for numerical stability in JIT
+        step_lr = xp.maximum(step_lr, self.min_lr)
         return step_lr
```

source/tests/universal/dpmodel/utils/test_learning_rate.py (2)
90-90: Use `atol` instead of `rtol` when comparing to zero.

When `desired=0.0`, the tolerance formula `|actual - desired| <= atol + rtol * |desired|` reduces to `|actual| <= 0` with default `atol=0`. This requires exact equality. Use `atol` for near-zero comparisons to avoid fragile tests.

Same issue applies to lines 104, 162, and 176.

♻️ Suggested fix

```diff
-        np.testing.assert_allclose(lr.value(0), 0.0, rtol=1e-10)
+        np.testing.assert_allclose(lr.value(0), 0.0, atol=1e-10)
```

Apply similar changes to lines 104, 162, and 176.
204-240: Consider adding tests for mutual-exclusion validation.

Per the PR objectives, the implementation adds "mutual-exclusion validation for warmup_steps vs warmup_ratio" and "mutual-exclusion validation for stop_lr vs stop_ratio". These validation behaviors should be tested to ensure they raise appropriate errors when both parameters are specified.

♻️ Suggested additional tests

```python
    def test_stop_lr_and_stop_lr_ratio_mutual_exclusion(self) -> None:
        """Test that specifying both stop_lr and stop_lr_ratio raises ValueError."""
        with self.assertRaises(ValueError):
            LearningRateExp(
                start_lr=1e-3,
                stop_lr=1e-5,
                stop_lr_ratio=0.01,
                num_steps=10000,
                decay_steps=5000,
            )

    def test_warmup_steps_and_warmup_ratio_mutual_exclusion(self) -> None:
        """Test that specifying both warmup_steps and warmup_ratio raises ValueError."""
        with self.assertRaises(ValueError):
            LearningRateExp(
                start_lr=1e-3,
                stop_lr=1e-5,
                num_steps=10000,
                decay_steps=5000,
                warmup_steps=1000,
                warmup_ratio=0.1,
            )
```
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- deepmd/dpmodel/utils/learning_rate.py
- deepmd/pd/utils/utils.py
- deepmd/pt/utils/utils.py
- deepmd/tf/train/trainer.py
- deepmd/tf/utils/learning_rate.py
- deepmd/utils/argcheck.py
- doc/train/training-advanced.md
- source/tests/universal/dpmodel/utils/test_learning_rate.py
🚧 Files skipped from review as they are similar to previous changes (1)
- deepmd/pd/utils/utils.py
🧰 Additional context used
🧬 Code graph analysis (4)
deepmd/tf/utils/learning_rate.py (1)
deepmd/dpmodel/utils/learning_rate.py (2)
BaseLR(25-189), value(146-189)
deepmd/tf/train/trainer.py (1)
deepmd/tf/utils/learning_rate.py (3)
LearningRateSchedule(20-119), start_lr(34-43), build(64-96)
deepmd/pt/utils/utils.py (1)
deepmd/pt/model/network/network.py (1)
Tensor(34-35)
deepmd/dpmodel/utils/learning_rate.py (3)
deepmd/tf/utils/learning_rate.py (2)
start_lr(34-43), value(98-119)
deepmd/pd/train/training.py (1)
step(722-963)
deepmd/pt/train/training.py (1)
step(773-1142)
🪛 LanguageTool
doc/train/training-advanced.md
[grammar] ~58-~58: Ensure spelling is correct
Context: ..."cosine"`). Both types support optional warmup and can use either absolute stopping le...
(QB_NEW_EN_ORTHOGRAPHY_ERROR_IDS_1)
🪛 markdownlint-cli2 (0.18.1)
doc/train/training-advanced.md
66-66: Hard tabs
Column: 1
(MD010, no-hard-tabs)
66-66: Hard tabs
Column: 9
(MD010, no-hard-tabs)
67-67: Hard tabs
Column: 1
(MD010, no-hard-tabs)
67-67: Hard tabs
Column: 13
(MD010, no-hard-tabs)
68-68: Hard tabs
Column: 1
(MD010, no-hard-tabs)
68-68: Hard tabs
Column: 12
(MD010, no-hard-tabs)
69-69: Hard tabs
Column: 1
(MD010, no-hard-tabs)
69-69: Hard tabs
Column: 16
(MD010, no-hard-tabs)
70-70: Hard tabs
Column: 1
(MD010, no-hard-tabs)
70-70: Hard tabs
Column: 13
(MD010, no-hard-tabs)
93-93: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
99-99: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
109-109: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
125-125: Emphasis used instead of a heading
(MD036, no-emphasis-as-heading)
129-129: Hard tabs
Column: 1
(MD010, no-hard-tabs)
129-129: Hard tabs
Column: 9
(MD010, no-hard-tabs)
130-130: Hard tabs
Column: 1
(MD010, no-hard-tabs)
130-130: Hard tabs
Column: 13
(MD010, no-hard-tabs)
131-131: Hard tabs
Column: 1
(MD010, no-hard-tabs)
131-131: Hard tabs
Column: 12
(MD010, no-hard-tabs)
132-132: Hard tabs
Column: 1
(MD010, no-hard-tabs)
132-132: Hard tabs
Column: 16
(MD010, no-hard-tabs)
136-136: Emphasis used instead of a heading
(MD036, no-emphasis-as-heading)
140-140: Hard tabs
Column: 1
(MD010, no-hard-tabs)
140-140: Hard tabs
Column: 9
(MD010, no-hard-tabs)
141-141: Hard tabs
Column: 1
(MD010, no-hard-tabs)
141-141: Hard tabs
Column: 13
(MD010, no-hard-tabs)
142-142: Hard tabs
Column: 1
(MD010, no-hard-tabs)
142-142: Hard tabs
Column: 18
(MD010, no-hard-tabs)
143-143: Hard tabs
Column: 1
(MD010, no-hard-tabs)
143-143: Hard tabs
Column: 16
(MD010, no-hard-tabs)
149-149: Emphasis used instead of a heading
(MD036, no-emphasis-as-heading)
153-153: Hard tabs
Column: 1
(MD010, no-hard-tabs)
153-153: Hard tabs
Column: 9
(MD010, no-hard-tabs)
154-154: Hard tabs
Column: 1
(MD010, no-hard-tabs)
154-154: Hard tabs
Column: 13
(MD010, no-hard-tabs)
155-155: Hard tabs
Column: 1
(MD010, no-hard-tabs)
155-155: Hard tabs
Column: 12
(MD010, no-hard-tabs)
156-156: Hard tabs
Column: 1
(MD010, no-hard-tabs)
156-156: Hard tabs
Column: 16
(MD010, no-hard-tabs)
157-157: Hard tabs
Column: 1
(MD010, no-hard-tabs)
157-157: Hard tabs
Column: 17
(MD010, no-hard-tabs)
158-158: Hard tabs
Column: 1
(MD010, no-hard-tabs)
164-164: Emphasis used instead of a heading
(MD036, no-emphasis-as-heading)
168-168: Hard tabs
Column: 1
(MD010, no-hard-tabs)
168-168: Hard tabs
Column: 9
(MD010, no-hard-tabs)
169-169: Hard tabs
Column: 1
(MD010, no-hard-tabs)
169-169: Hard tabs
Column: 13
(MD010, no-hard-tabs)
170-170: Hard tabs
Column: 1
(MD010, no-hard-tabs)
170-170: Hard tabs
Column: 18
(MD010, no-hard-tabs)
171-171: Hard tabs
Column: 1
(MD010, no-hard-tabs)
171-171: Hard tabs
Column: 16
(MD010, no-hard-tabs)
172-172: Hard tabs
Column: 1
(MD010, no-hard-tabs)
172-172: Hard tabs
Column: 17
(MD010, no-hard-tabs)
184-184: Hard tabs
Column: 1
(MD010, no-hard-tabs)
184-184: Hard tabs
Column: 9
(MD010, no-hard-tabs)
185-185: Hard tabs
Column: 1
(MD010, no-hard-tabs)
185-185: Hard tabs
Column: 13
(MD010, no-hard-tabs)
186-186: Hard tabs
Column: 1
(MD010, no-hard-tabs)
186-186: Hard tabs
Column: 12
(MD010, no-hard-tabs)
192-192: Emphasis used instead of a heading
(MD036, no-emphasis-as-heading)
196-196: Hard tabs
Column: 1
(MD010, no-hard-tabs)
196-196: Hard tabs
Column: 9
(MD010, no-hard-tabs)
197-197: Hard tabs
Column: 1
(MD010, no-hard-tabs)
197-197: Hard tabs
Column: 13
(MD010, no-hard-tabs)
198-198: Hard tabs
Column: 1
(MD010, no-hard-tabs)
198-198: Hard tabs
Column: 12
(MD010, no-hard-tabs)
202-202: Emphasis used instead of a heading
(MD036, no-emphasis-as-heading)
206-206: Hard tabs
Column: 1
(MD010, no-hard-tabs)
206-206: Hard tabs
Column: 9
(MD010, no-hard-tabs)
207-207: Hard tabs
Column: 1
(MD010, no-hard-tabs)
207-207: Hard tabs
Column: 13
(MD010, no-hard-tabs)
208-208: Hard tabs
Column: 1
(MD010, no-hard-tabs)
208-208: Hard tabs
Column: 18
(MD010, no-hard-tabs)
214-214: Emphasis used instead of a heading
(MD036, no-emphasis-as-heading)
218-218: Hard tabs
Column: 1
(MD010, no-hard-tabs)
218-218: Hard tabs
Column: 9
(MD010, no-hard-tabs)
219-219: Hard tabs
Column: 1
(MD010, no-hard-tabs)
219-219: Hard tabs
Column: 13
(MD010, no-hard-tabs)
220-220: Hard tabs
Column: 1
(MD010, no-hard-tabs)
220-220: Hard tabs
Column: 12
(MD010, no-hard-tabs)
221-221: Hard tabs
Column: 1
(MD010, no-hard-tabs)
221-221: Hard tabs
Column: 17
(MD010, no-hard-tabs)
222-222: Hard tabs
Column: 1
(MD010, no-hard-tabs)
228-228: Emphasis used instead of a heading
(MD036, no-emphasis-as-heading)
232-232: Hard tabs
Column: 1
(MD010, no-hard-tabs)
232-232: Hard tabs
Column: 9
(MD010, no-hard-tabs)
233-233: Hard tabs
Column: 1
(MD010, no-hard-tabs)
233-233: Hard tabs
Column: 13
(MD010, no-hard-tabs)
234-234: Hard tabs
Column: 1
(MD010, no-hard-tabs)
234-234: Hard tabs
Column: 12
(MD010, no-hard-tabs)
235-235: Hard tabs
Column: 1
(MD010, no-hard-tabs)
235-235: Hard tabs
Column: 16
(MD010, no-hard-tabs)
236-236: Hard tabs
Column: 1
(MD010, no-hard-tabs)
236-236: Hard tabs
Column: 11
(MD010, no-hard-tabs)
🪛 Ruff (0.14.11)
deepmd/tf/utils/learning_rate.py
61-61: Avoid specifying long messages outside the exception class
(TRY003)
118-118: Avoid specifying long messages outside the exception class
(TRY003)
deepmd/pt/utils/utils.py
243-243: Avoid specifying long messages outside the exception class
(TRY003)
deepmd/dpmodel/utils/learning_rate.py
40-40: Unused method argument: kwargs
(ARG002)
75-78: Avoid specifying long messages outside the exception class
(TRY003)
80-83: Avoid specifying long messages outside the exception class
(TRY003)
96-99: Avoid specifying long messages outside the exception class
(TRY003)
109-109: Avoid specifying long messages outside the exception class
(TRY003)
111-111: Avoid specifying long messages outside the exception class
(TRY003)
113-113: Avoid specifying long messages outside the exception class
(TRY003)
309-309: Avoid specifying long messages outside the exception class
(TRY003)
312-315: Avoid specifying long messages outside the exception class
(TRY003)
deepmd/utils/argcheck.py
2506-2509: Avoid specifying long messages outside the exception class
(TRY003)
2511-2514: Avoid specifying long messages outside the exception class
(TRY003)
2542-2545: Avoid specifying long messages outside the exception class
(TRY003)
2575-2575: Avoid specifying long messages outside the exception class
(TRY003)
2579-2579: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (40)
- GitHub Check: Test Python (11, 3.13)
- GitHub Check: Test Python (10, 3.10)
- GitHub Check: Test Python (11, 3.10)
- GitHub Check: Test Python (9, 3.13)
- GitHub Check: Test Python (7, 3.10)
- GitHub Check: Test Python (8, 3.10)
- GitHub Check: Test Python (4, 3.10)
- GitHub Check: Test Python (5, 3.13)
- GitHub Check: Test Python (4, 3.13)
- GitHub Check: Test Python (7, 3.13)
- GitHub Check: Test Python (6, 3.10)
- GitHub Check: Test Python (8, 3.13)
- GitHub Check: Test Python (10, 3.13)
- GitHub Check: Test Python (6, 3.13)
- GitHub Check: Test Python (5, 3.10)
- GitHub Check: Test Python (9, 3.10)
- GitHub Check: Test Python (3, 3.10)
- GitHub Check: Test Python (12, 3.10)
- GitHub Check: Test Python (12, 3.13)
- GitHub Check: Test Python (2, 3.10)
- GitHub Check: Test Python (1, 3.10)
- GitHub Check: Test Python (1, 3.13)
- GitHub Check: Test Python (3, 3.13)
- GitHub Check: Test Python (2, 3.13)
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Analyze (python)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Test C++ (false, true, true, false)
- GitHub Check: Test C++ (false, false, false, true)
- GitHub Check: Test C++ (true, true, true, false)
- GitHub Check: Test C++ (true, false, false, true)
🔇 Additional comments (13)
deepmd/pt/utils/utils.py (2)
19-19: LGTM! Import is correctly added to support the new scalar conversion logic.

235-236: LGTM! Scalar handling correctly wraps numeric inputs into numpy arrays with the configured global precision.

deepmd/tf/utils/learning_rate.py (1)

20-44: Clean wrapper API (`start_lr()` + params copy) and clear intent.

deepmd/tf/train/trainer.py (1)

106-118: `get_lr_and_coef()` update is straightforward and keeps `scale_by_worker` out of LR params.

doc/train/training-advanced.md (1)

9-52: Theory section reads consistent with the new warmup + (exp/cosine) design, including explicit floor() for stepped exp.

deepmd/utils/argcheck.py (1)

2583-2756: Good consolidation: shared args + centralized `extra_check` keeps LR variants consistent.

deepmd/dpmodel/utils/learning_rate.py (2)

146-189: Warmup/decay split in `BaseLR.value()` is readable and matches the documented piecewise schedule.

443-477: Cosine decay clamps correctly past the decay phase end (prevents cosine "bounce back").

source/tests/universal/dpmodel/utils/test_learning_rate.py (5)

1-12: LGTM! Imports are appropriate for the test module.

15-50: LGTM! Good test coverage for basic exponential decay functionality, including the `stop_lr_ratio` and `decay_rate` override.

53-74: LGTM! Good test coverage for cosine annealing, including the mathematically correct midpoint check.

147-178: LGTM! Good test coverage for array/batch inputs, which is important for JIT compatibility. Shape verification is appropriate.

181-201: LGTM! Good edge-case coverage for behavior beyond `num_steps`, verifying the clamping to `stop_lr`.
When decay_steps exceeds the decay phase (num_steps - warmup_steps) and decay_rate is not explicitly provided, automatically adjust decay_steps to a sensible default (capped at 100, or decay_total//100 + 1) instead of raising ValueError. This makes the learning rate scheduler more user-friendly by gracefully handling misconfigured decay_steps values.

Changes:

- LearningRateExp: auto-adjust decay_steps when >= decay_total
- Update argcheck and training-advanced.md documentation
- Update pd/pt/tf test_lr.py to use auto-adjusted decay_steps
- Remove obsolete validation tests in test_learning_rate.py
- Fix tf test dtype: float32 -> float64
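One plausible reading of the auto-adjustment described here, written as a standalone sketch; the exact default used by `LearningRateExp` may differ.

```python
def auto_adjust_decay_steps(decay_steps, num_steps, warmup_steps, decay_rate=None):
    decay_total = num_steps - warmup_steps
    if decay_rate is None and decay_steps >= decay_total:
        # Cap at 100, falling back to decay_total // 100 + 1 for short runs.
        return min(100, decay_total // 100 + 1)
    return decay_steps
```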
Pull request overview
Copilot reviewed 36 out of 36 changed files in this pull request and generated 7 comments.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@deepmd/tf/utils/learning_rate.py`:
- Around line 92-107: The closure _lr_value currently references self._base_lr
at evaluation time, which can break if build() is called multiple times; capture
the BaseLR instance into a local variable (e.g., base_lr = self._base_lr) inside
LearningRateSchedule.build() before defining _lr_value, then have _lr_value use
that local base_lr instead of self._base_lr so each returned tensor is bound to
the BaseLR instance present at closure creation.
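A trimmed-down sketch of the closure-capture fix; the factory argument is a hypothetical stand-in for however `build()` constructs its `BaseLR`.

```python
import numpy as np
import tensorflow as tf


class LearningRateSchedule:
    def __init__(self, make_base_lr):
        self._make_base_lr = make_base_lr  # hypothetical BaseLR factory
        self._base_lr = None

    def build(self, global_step, num_steps):
        self._base_lr = self._make_base_lr(num_steps)
        base_lr = self._base_lr  # bind the instance created by this build()

        def _lr_value(step: np.ndarray) -> np.ndarray:
            # Uses the captured base_lr, not self._base_lr, so a later
            # rebuild cannot change tensors returned by an earlier build().
            return np.asarray(base_lr.value(step), dtype=np.float64)

        return tf.numpy_function(_lr_value, [global_step], Tout=tf.float64)
```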
♻️ Duplicate comments (2)
deepmd/tf/utils/learning_rate.py (2)
49-66: TRY003 lint warning on RuntimeError message.

This was already flagged in a previous review. Consider adding `# noqa: TRY003` or defining a custom exception class if CI enforces this rule.

109-130: TRY003 lint warning on RuntimeError message (same as `base_lr` property).

Already addressed in past review. Apply the same fix (noqa comment or custom exception) here as well.
🧹 Nitpick comments (2)
deepmd/pt/utils/utils.py (1)
230-242: Add `int` to the type annotations for consistency.

The implementation handles `int` at line 239 (`isinstance(xx, (float, int))`), but the overloads and main signature don't include it. This creates a type annotation mismatch.

♻️ Proposed fix to include int in type annotations

```diff
 @overload
-def to_numpy_array(xx: float) -> np.ndarray: ...
+def to_numpy_array(xx: float | int) -> np.ndarray: ...


 def to_numpy_array(
-    xx: torch.Tensor | np.ndarray | float | None,
+    xx: torch.Tensor | np.ndarray | float | int | None,
 ) -> np.ndarray | None:
```

deepmd/tf/utils/learning_rate.py (1)

38-47: Consider making `start_lr` a property for API consistency.

`base_lr` is defined as a property, but `start_lr` is a method. This inconsistency may confuse users of the API.

♻️ Suggested refactor

```diff
-    def start_lr(self) -> float:
+    @property
+    def start_lr(self) -> float:
         """
         Get the starting learning rate.

         Returns
         -------
         float
             The starting learning rate.
         """
         return float(self._params["start_lr"])
```

Please verify that callers use `start_lr` as a method (with parentheses) or update them if changing to a property:

```bash
#!/bin/bash
# Search for usages of start_lr to verify call patterns
rg -n "\.start_lr\(" --type py
```
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- deepmd/pt/utils/utils.py
- deepmd/tf/train/trainer.py
- deepmd/tf/utils/learning_rate.py
- source/tests/universal/dpmodel/utils/test_learning_rate.py
🚧 Files skipped from review as they are similar to previous changes (2)
- source/tests/universal/dpmodel/utils/test_learning_rate.py
- deepmd/tf/train/trainer.py
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2024-10-15T22:22:24.889Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4219
File: deepmd/utils/learning_rate.py:48-53
Timestamp: 2024-10-15T22:22:24.889Z
Learning: Methods in `deepmd/utils/learning_rate.py` that return NumPy scalar types should have return type annotations using the corresponding NumPy types, such as `np.float64`.
Applied to files:
deepmd/tf/utils/learning_rate.py
deepmd/pt/utils/utils.py
📚 Learning: 2024-10-26T02:09:01.365Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4258
File: deepmd/jax/utils/neighbor_stat.py:98-101
Timestamp: 2024-10-26T02:09:01.365Z
Learning: The function `to_jax_array` in `deepmd/jax/common.py` can handle `None` values, so it's safe to pass `None` to it without additional checks.
Applied to files:
deepmd/pt/utils/utils.py
🧬 Code graph analysis (2)
deepmd/tf/utils/learning_rate.py (1)
deepmd/dpmodel/utils/learning_rate.py (2)
BaseLR(25-189), value(146-189)
deepmd/pt/utils/utils.py (3)
deepmd/pd/utils/utils.py (3)
to_numpy_array(231-231), to_numpy_array(235-235), to_numpy_array(238-256)
deepmd/dpmodel/common.py (1)
to_numpy_array(106-128)
deepmd/pt/model/network/network.py (1)
Tensor(34-35)
🪛 Ruff (0.14.11)
deepmd/tf/utils/learning_rate.py
65-65: Avoid specifying long messages outside the exception class
(TRY003)
129-129: Avoid specifying long messages outside the exception class
(TRY003)
deepmd/pt/utils/utils.py
249-249: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (40)
- GitHub Check: Test Python (9, 3.10)
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Test Python (11, 3.13)
- GitHub Check: Test Python (7, 3.10)
- GitHub Check: Test Python (11, 3.10)
- GitHub Check: Test Python (10, 3.10)
- GitHub Check: Test Python (10, 3.13)
- GitHub Check: Test Python (9, 3.13)
- GitHub Check: Test Python (12, 3.13)
- GitHub Check: Test Python (1, 3.10)
- GitHub Check: Test Python (6, 3.10)
- GitHub Check: Test Python (3, 3.10)
- GitHub Check: Test Python (12, 3.10)
- GitHub Check: Test Python (2, 3.10)
- GitHub Check: Test Python (6, 3.13)
- GitHub Check: Test Python (7, 3.13)
- GitHub Check: Test Python (3, 3.13)
- GitHub Check: Test Python (1, 3.13)
- GitHub Check: Test Python (5, 3.13)
- GitHub Check: Test Python (4, 3.13)
- GitHub Check: Test Python (2, 3.13)
- GitHub Check: Test Python (8, 3.13)
- GitHub Check: Test Python (4, 3.10)
- GitHub Check: Test Python (8, 3.10)
- GitHub Check: Test Python (5, 3.10)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Analyze (python)
- GitHub Check: Test C++ (false, true, true, false)
- GitHub Check: Test C++ (true, true, true, false)
- GitHub Check: Test C++ (false, false, false, true)
- GitHub Check: Test C++ (true, false, false, true)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
🔇 Additional comments (3)
deepmd/pt/utils/utils.py (1)
241-254: LGTM! The numpy array handling is correctly positioned.

The `np.ndarray` check now comes before the precision lookup (lines 241-242), which correctly handles numpy arrays by converting to `GLOBAL_NP_FLOAT_PRECISION` without attempting to use `reverse_precision_dict` (which contains PyTorch dtypes). The torch.Tensor path (lines 243-254) correctly maps through `PT_PRECISION_DICT` → `NP_PRECISION_DICT` and handles bfloat16 conversion.

deepmd/tf/utils/learning_rate.py (2)

34-36: LGTM! The defensive shallow copy via `dict(params)` prevents mutation of the caller's dictionary, and lazy initialization of `_base_lr` is appropriate for the two-phase initialization pattern.

133-135: LGTM! The `__all__` export correctly exposes only the public `LearningRateSchedule` class.
Codecov Report

❌ Patch coverage is

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##           master    #5154      +/-   ##
==========================================
- Coverage   81.94%   81.93%   -0.01%
==========================================
  Files         713      714       +1
  Lines       73010    73380     +370
  Branches     3616     3617       +1
==========================================
+ Hits        59825    60125     +300
- Misses      12022    12093      +71
+ Partials     1163     1162       -1
```

☔ View full report in Codecov by Sentry.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@deepmd/dpmodel/utils/learning_rate.py`:
- Around line 322-329: The decay_rate calculation can divide by zero when
num_steps (hence decay_total) is 0; update the block that sets self.decay_rate
(the else branch using np.exp(np.log(clamped_stop_lr / self.start_lr) /
(decay_total / self.decay_steps)).item()) to guard against decay_total <= 0 or
self.decay_steps == 0: when either is zero, set self.decay_rate to 1.0 (no
decay) instead of performing the log/division, otherwise keep the existing
calculation using clamped_stop_lr, self.start_lr, decay_total and
self.decay_steps.
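The guard, written out as a small standalone sketch with the variable names used in the comment; the real code assigns `self.decay_rate` inside `LearningRateExp.__init__`.

```python
import numpy as np


def infer_decay_rate(start_lr, clamped_stop_lr, decay_total, decay_steps):
    if decay_total <= 0 or decay_steps == 0:
        return 1.0  # no decay phase, keep the learning rate flat
    return float(
        np.exp(np.log(clamped_stop_lr / start_lr) / (decay_total / decay_steps))
    )
```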
🧹 Nitpick comments (1)
deepmd/dpmodel/utils/learning_rate.py (1)
91-99: Clarify: `warmup_steps=0` is treated as "not provided".

The check `has_warmup_steps = warmup_steps != 0` means explicitly setting `warmup_steps=0` while also providing `warmup_ratio` will silently use `warmup_ratio` instead of raising a mutual exclusivity error.

This is likely intentional since `warmup_steps=0` is the default, but it differs from the `stop_lr`/`stop_lr_ratio` check which uses `is not None`. Consider adding a brief comment clarifying this design choice for maintainability.
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- deepmd/dpmodel/utils/learning_rate.py
- deepmd/tf/loss/dos.py
- deepmd/tf/loss/ener.py
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2024-10-08T15:32:11.479Z
Learnt from: 1azyking
Repo: deepmodeling/deepmd-kit PR: 4169
File: deepmd/pt/loss/ener_hess.py:341-348
Timestamp: 2024-10-08T15:32:11.479Z
Learning: In `deepmd/pt/loss/ener_hess.py`, the `label` uses the key `"atom_ener"` intentionally to maintain consistency with the forked version.
Applied to files:
deepmd/tf/loss/dos.py
deepmd/tf/loss/ener.py
📚 Learning: 2024-10-05T03:06:02.372Z
Learnt from: 1azyking
Repo: deepmodeling/deepmd-kit PR: 4169
File: deepmd/utils/argcheck.py:1982-2117
Timestamp: 2024-10-05T03:06:02.372Z
Learning: The `loss_ener_hess` and `loss_ener` functions should remain separate to avoid confusion, despite code duplication.
Applied to files:
deepmd/tf/loss/ener.py
📚 Learning: 2026-01-10T04:29:25.299Z
Learnt from: OutisLi
Repo: deepmodeling/deepmd-kit PR: 5130
File: deepmd/pt/optimizer/adamuon.py:44-113
Timestamp: 2026-01-10T04:29:25.299Z
Learning: The epsilon parameter (eps=1e-8) in the `zeropower_via_newtonschulz5` function in `deepmd/pt/optimizer/adamuon.py` follows the standard PyTorch AdaMuon reference implementations (which typically use eps=1e-7) and should not be modified for consistency with reference implementations, despite being smaller than bfloat16's machine epsilon.
Applied to files:
deepmd/dpmodel/utils/learning_rate.py
🧬 Code graph analysis (2)
deepmd/tf/loss/dos.py (1)
deepmd/tf/env.py (1)
global_cvt_2_tf_float(409-422)
deepmd/tf/loss/ener.py (1)
deepmd/tf/env.py (2)
global_cvt_2_ener_float(425-438), global_cvt_2_tf_float(409-422)
🪛 Ruff (0.14.11)
deepmd/dpmodel/utils/learning_rate.py
40-40: Unused method argument: kwargs
(ARG002)
75-78: Avoid specifying long messages outside the exception class
(TRY003)
80-83: Avoid specifying long messages outside the exception class
(TRY003)
96-99: Avoid specifying long messages outside the exception class
(TRY003)
109-109: Avoid specifying long messages outside the exception class
(TRY003)
111-111: Avoid specifying long messages outside the exception class
(TRY003)
113-113: Avoid specifying long messages outside the exception class
(TRY003)
115-115: Avoid specifying long messages outside the exception class
(TRY003)
310-310: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (12)
- GitHub Check: Test Python (4, 3.10)
- GitHub Check: Test Python (2, 3.10)
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Test C++ (true, false, false, true)
- GitHub Check: Test C++ (false, true, true, false)
- GitHub Check: Test C++ (false, false, false, true)
- GitHub Check: Test C++ (true, true, true, false)
🔇 Additional comments (10)
deepmd/tf/loss/dos.py (1)
68-69: LGTM! Conversion wrappers ensure consistent tensor typing.

The addition of `global_cvt_2_tf_float` for `find_dos` and `find_atom_dos` ensures these indicator values are properly cast to the global TensorFlow precision before being used in the preference coefficient calculations. This aligns with the consistent pattern applied across other loss modules (e.g., `find_force`, `find_virial` in `ener.py`).

deepmd/tf/loss/ener.py (3)

152-159: LGTM! Appropriate precision casting for each indicator field.

The conversion function choices are correct:

- `find_energy` → `global_cvt_2_ener_float` (matches `pref_e`'s energy precision)
- `find_force`, `find_virial`, `find_atom_ener`, `find_atom_pref`, `find_drdq` → `global_cvt_2_tf_float` (matches their respective prefactors' precision)

This ensures type consistency in the arithmetic operations that follow.

592-596: LGTM! Consistent with `EnerStdLoss` pattern.

The conversion function assignments mirror the pattern in `EnerStdLoss`: `find_energy` uses energy precision while force/virial/atom_ener indicators use TF float precision. This maintains consistency across the loss module implementations.

935-936: LGTM! Energy-related indicators use energy precision.

Both `find_energy` and `find_ener_dipole` are energy-related indicators and appropriately use `global_cvt_2_ener_float`. The outer wrappers on `pref_e` and `pref_ed` (lines 951 and 960) will handle the final precision casting for the preference coefficients.

deepmd/dpmodel/utils/learning_rate.py (6)

1-29: LGTM! The imports and class registration pattern using the plugin registry are well-structured and consistent with the codebase architecture.

126-146: LGTM! The abstract `_decay_value` method defines a clean contract for subclasses, with clear documentation about the step being relative to the end of warmup.

148-191: LGTM! The `value()` method correctly handles:

- Scalar vs array input conversion
- Linear warmup phase with proper interpolation
- Decay phase delegation to subclass `_decay_value`
- Boundary continuity at `warmup_steps` (decay starts at step=0 relative to warmup end, returning `start_lr`)

The use of `xp.where` ensures correct branching behavior compatible with JIT compilation.

334-362: LGTM! The `_decay_value` implementation correctly handles both stepped and smooth exponential decay modes, with proper array API usage and device placement for the decay_rate constant.

462-480: LGTM! Good handling of edge cases:

- `decay_num_steps=0` returns `start_lr` (no decay phase)
- Steps beyond `decay_num_steps` are clipped to `min_lr` via `xp.where`

The cosine annealing formula implementation is mathematically correct and matches the docstring.

40-40: Note: Unused `kwargs` is intentional for plugin extensibility.

The static analysis flags `kwargs` as unused (ARG002), but this is intentional for the plugin registry pattern; subclasses may receive additional parameters via `super().__init__(**kwargs)`. No change needed.
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@source/tests/tf/test_lr.py`:
- Around line 44-51: The test test_build_returns_tensor hardcodes tf.float64 for
the dtype but the implementation uses GLOBAL_TF_FLOAT_PRECISION; update the
assertion to compare lr_tensor.dtype against GLOBAL_TF_FLOAT_PRECISION (or
tf.as_dtype(GLOBAL_TF_FLOAT_PRECISION) if GLOBAL_TF_FLOAT_PRECISION is a Python
float type) so the test reads the expected dtype from the same source as
LearningRateSchedule.build and will pass for both float32 and float64
configurations.
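A sketch of the suggested assertion, reading the expected dtype from the same constant the implementation uses (the import path is assumed).

```python
from deepmd.tf.env import GLOBAL_TF_FLOAT_PRECISION


def check_lr_dtype(lr_tensor) -> None:
    # Passes for both float32 and float64 builds because the expectation
    # and the implementation share one source of truth.
    assert lr_tensor.dtype == GLOBAL_TF_FLOAT_PRECISION
```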
♻️ Duplicate comments (1)
deepmd/tf/utils/learning_rate.py (1)
50-67: Consider silencing Ruff TRY003 if CI enforces it.

The RuntimeError message is reasonable and clear. If CI enforces TRY003, add `# noqa: TRY003` to line 66.
🧹 Nitpick comments (2)
deepmd/tf/utils/learning_rate.py (1)
96-102: Misleading comment about dtype precision.

The comment says "(float64)" but `GLOBAL_TF_FLOAT_PRECISION` can be either `float32` or `float64` depending on configuration. Consider removing the hardcoded precision from the comment to avoid confusion.

Suggested fix

```diff
     def _lr_value(step: np.ndarray) -> np.ndarray:
-        # Use GLOBAL_TF_FLOAT_PRECISION (float64) for learning rate,
-        # consistent with energy precision in TF backend
+        # Use GLOBAL_TF_FLOAT_PRECISION for learning rate,
+        # consistent with model precision in TF backend
         return np.asarray(
             base_lr.value(step),
             dtype=GLOBAL_TF_FLOAT_PRECISION.as_numpy_dtype,
         )
```

source/tests/tf/test_lr.py (1)

87-104: Consider consolidating with `test_tensor_value_matches_base_lr`.

This test is functionally identical to `test_tensor_value_matches_base_lr`; both verify that `value()` matches `base_lr.value()` after `build()`. Consider consolidating them to reduce duplication, or keep both if the intent is to emphasize different aspects (post-build accessibility vs. value correctness).
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- deepmd/tf/utils/learning_rate.py
- source/tests/tf/test_lr.py
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2024-10-15T22:22:24.889Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4219
File: deepmd/utils/learning_rate.py:48-53
Timestamp: 2024-10-15T22:22:24.889Z
Learning: Methods in `deepmd/utils/learning_rate.py` that return NumPy scalar types should have return type annotations using the corresponding NumPy types, such as `np.float64`.
Applied to files:
deepmd/tf/utils/learning_rate.py
📚 Learning: 2025-12-12T13:40:14.334Z
Learnt from: CR
Repo: deepmodeling/deepmd-kit PR: 0
File: AGENTS.md:0-0
Timestamp: 2025-12-12T13:40:14.334Z
Learning: Run core tests with `pytest source/tests/tf/test_dp_test.py::TestDPTestEner::test_1frame -v` to validate basic functionality
Applied to files:
source/tests/tf/test_lr.py
🧬 Code graph analysis (2)
deepmd/tf/utils/learning_rate.py (1)
deepmd/dpmodel/utils/learning_rate.py (2)
BaseLR(25-191), value(148-191)
source/tests/tf/test_lr.py (2)
deepmd/dpmodel/utils/learning_rate.py (2)
LearningRateExp(195-365), value(148-191)
deepmd/tf/utils/learning_rate.py (5)
LearningRateSchedule(21-131), value(110-131), base_lr(51-67), build(69-108), start_lr(39-48)
🪛 Ruff (0.14.11)
deepmd/tf/utils/learning_rate.py
66-66: Avoid specifying long messages outside the exception class
(TRY003)
130-130: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (38)
- GitHub Check: Test Python (8, 3.10)
- GitHub Check: Test Python (12, 3.10)
- GitHub Check: Test Python (11, 3.13)
- GitHub Check: Test Python (7, 3.13)
- GitHub Check: Test Python (4, 3.13)
- GitHub Check: Test Python (12, 3.13)
- GitHub Check: Test Python (10, 3.13)
- GitHub Check: Test Python (6, 3.10)
- GitHub Check: Test Python (6, 3.13)
- GitHub Check: Test Python (5, 3.13)
- GitHub Check: Test Python (4, 3.10)
- GitHub Check: Test Python (5, 3.10)
- GitHub Check: Test Python (1, 3.13)
- GitHub Check: Test Python (1, 3.10)
- GitHub Check: Test Python (10, 3.10)
- GitHub Check: Test Python (9, 3.13)
- GitHub Check: Test Python (2, 3.13)
- GitHub Check: Test Python (8, 3.13)
- GitHub Check: Test Python (2, 3.10)
- GitHub Check: Test Python (7, 3.10)
- GitHub Check: Test Python (3, 3.13)
- GitHub Check: Test Python (3, 3.10)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Test C++ (true, false, false, true)
- GitHub Check: Test C++ (false, false, false, true)
- GitHub Check: Test C++ (true, true, true, false)
- GitHub Check: Test C++ (false, true, true, false)
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Analyze (python)
- GitHub Check: Analyze (c-cpp)
🔇 Additional comments (9)
deepmd/tf/utils/learning_rate.py (4)
2-18: LGTM! Imports are well-organized. The `GLOBAL_TF_FLOAT_PRECISION` import and `BaseLR` delegation align with the refactored architecture.

21-48: LGTM! The docstring appropriately documents the `tf.numpy_function` performance consideration. The `start_lr()` accessor is clean and straightforward.

110-131: LGTM! The `value()` method correctly delegates to `BaseLR.value()` and handles the built-check. The same TRY003 consideration applies to line 130 as noted earlier.

134-136: LGTM! Clean export of the renamed public class.

source/tests/tf/test_lr.py (5)

1-20: LGTM! Clear module docstring establishing test scope. Good separation of concerns: TF wrapper logic is tested here while core algorithms are tested in dpmodel tests.

23-38: LGTM! Good coverage of pre-build error handling. The assertions verify both the exception type and message content.

53-59: LGTM! Good verification that the default scheduler type resolves to `LearningRateExp`.

82-85: LGTM! Clean test for the `start_lr()` accessor.

107-108: LGTM! Standard test runner entry point.
```python
self.optimizer,
lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
lambda step: self.lr_exp.value(step + self.start_step)
/ self.lr_exp.start_lr,
```
I would suggest providing a method for accessing the start_lr, rather than directly reads the data of the object.
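One way to read this suggestion: give the scheduler an explicit accessor and divide by that in the LambdaLR factor, instead of reaching into the attribute. The accessor name is hypothetical and the surrounding trainer code is elided.

```python
class LearningRateExp:  # heavily trimmed sketch, not the real class
    def __init__(self, start_lr: float) -> None:
        self._start_lr = start_lr

    def get_start_lr(self) -> float:  # hypothetical accessor name
        return self._start_lr

    def value(self, step: int) -> float:
        return self._start_lr  # decay omitted in this sketch


# in the trainer, mirroring the snippet above:
# lambda step: self.lr_exp.value(step + self.start_step) / self.lr_exp.get_start_lr()
```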
Summary by CodeRabbit
New Features
Improvements
Documentation
Tests