
Conversation

@OutisLi
Collaborator

@OutisLi OutisLi commented Jan 15, 2026

  • Refactor BaseLR in dpmodel to use array_api_compat for backend-agnostic implementation
  • Consolidate learning rate logic from TF/PT/PD backends into unified dpmodel layer
  • Use array API operations (xp.where, xp.clip, etc.) for JIT compatibility across backends
  • Add warmup support (warmup_steps, warmup_ratio, warmup_start_factor) during refactoring
  • Add stop_ratio parameter as alternative to stop_lr for flexible configuration
  • Implement mutual exclusion validation for stop_lr/stop_ratio and warmup_steps/warmup_ratio
  • Update all backends to use unified BaseLR implementation
  • Add comprehensive consistency tests across NumPy/PyTorch/JAX/array_api_strict backends

Summary by CodeRabbit

  • New Features

    • Optional warmup phase for learning-rate schedules (steps/ratio/start factor).
    • Cosine annealing schedule and a TensorFlow-friendly learning-rate wrapper.
    • stop_lr_ratio to specify final LR as a ratio.
  • Improvements

    • Unified LR behavior across backends with stricter validations and clearer error checks.
    • Array conversion utilities now accept numeric (scalar) inputs.
  • Documentation

    • Training guide updated with two-phase LR model and examples.
  • Tests

    • Expanded LR tests covering warmup, cosine, ratios, and edge cases.


Copilot AI review requested due to automatic review settings January 15, 2026 04:27
@dosubot dosubot bot added the enhancement label Jan 15, 2026
@coderabbitai
Contributor

coderabbitai bot commented Jan 15, 2026

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough


Refactors learning-rate subsystem: adds BaseLR with warmup support and an abstract _decay_value/value contract, updates exponential and cosine schedulers, introduces stop_lr_ratio/num_steps, converts TF wrapper to LearningRateSchedule (dict-based), updates training/scheduler usage, adds validations, docs, examples, and tests.
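For orientation, the contract described above might look roughly like the sketch below (illustrative only; the real BaseLR also handles stop_lr_ratio, array inputs via the array API, and the runtime validations discussed later in this review):

from abc import ABC, abstractmethod


class BaseLRSketch(ABC):
    """Illustrative stand-in: subclasses supply the decay curve, the base class owns value()."""

    def __init__(self, start_lr: float, stop_lr: float, num_steps: int, warmup_steps: int = 0) -> None:
        self.start_lr = start_lr
        self.stop_lr = stop_lr
        self.num_steps = num_steps
        self.warmup_steps = warmup_steps
        # Steps remaining for the decay phase after warmup.
        self.decay_num_steps = num_steps - warmup_steps

    @abstractmethod
    def _decay_value(self, step):
        """Return the decayed LR for a step counted from the end of warmup."""

    def value(self, step):
        if self.warmup_steps and step < self.warmup_steps:
            # Simplified linear warmup from 0 to start_lr.
            return self.start_lr * step / self.warmup_steps
        return self._decay_value(step - self.warmup_steps)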

Changes

  • Core LR implementation (deepmd/dpmodel/utils/learning_rate.py): Introduces BaseLR with warmup support, abstract _decay_value(step) and public value(step). LearningRateExp and LearningRateCosine now implement _decay_value, accept stop_lr/stop_lr_ratio, num_steps, and warmup params, compute decay_num_steps, clamp min_lr, and add runtime validations.
  • TF wrapper & exports (deepmd/tf/utils/learning_rate.py, deepmd/tf/utils/__init__.py): Renames LearningRateExp → LearningRateSchedule; the wrapper accepts params: dict, lazily builds a BaseLR backend in build(global_step, num_steps) via numpy_function, exposes start_lr(), a base_lr property, and value(step); exports updated.
  • Argument validation / CLI args (deepmd/utils/argcheck.py): Adds _check_lr_stop_args, _check_warmup_args, _check_decay_steps_args, and _learning_rate_common_args; learning-rate arg construction unified and wired with extra_check validations.
  • Training code (deepmd/pd/train/training.py, deepmd/pt/train/training.py, deepmd/tf/train/trainer.py): Switches lr params to use num_steps (replacing stop_steps), removes warmup-specific branching in scheduler lambdas, and makes LR lambdas uniformly call lr_exp.value(step+start_step)/lr_exp.start_lr; TF trainer type hints and LR construction updated to LearningRateSchedule.
  • Utilities: array conversion (deepmd/pd/utils/utils.py, deepmd/pt/utils/utils.py): to_numpy_array now accepts scalar numeric and np.ndarray, casts non-tensor inputs to NumPy with configured precision, and adds overloads/type-hint updates.
  • Tests: unit & integration (source/tests/universal/dpmodel/utils/test_learning_rate.py, source/tests/tf/test_lr.py, source/tests/*/test_lr.py, source/tests/consistent/test_learning_rate.py, updated test_model files): Adds comprehensive LR unit tests (exp/cosine, warmup, array input, beyond-num-steps), TF wrapper tests, and updates consistency/integration tests to use LearningRateSchedule, num_steps, and warmup checks.
  • Docs & examples (doc/train/training-advanced.md, examples/**/input*.json, many source/tests/**/*.json): Docs updated for the two-phase schedule (warmup + decay) with expanded math/examples; many example JSONs reformatted and warmup_steps removed or adjusted; some stop_lr fields added/normalized.

Sequence Diagram(s)

sequenceDiagram
    participant Trainer
    participant TFWrapper as LearningRateSchedule
    participant BaseLR
    participant NumpyFunc as numpy_function

    Trainer->>TFWrapper: build(global_step, num_steps)
    TFWrapper->>BaseLR: instantiate BaseLR(params + num_steps)
    TFWrapper->>NumpyFunc: register callable -> BaseLR.value
    Trainer->>NumpyFunc: runtime global_step -> call BaseLR.value(step)
    NumpyFunc->>BaseLR: value(step)
    alt step < warmup_steps
        BaseLR->>BaseLR: compute warmup lr (linear interp)
    else
        BaseLR->>BaseLR: _decay_value(step - warmup_steps) -> exp/cosine
    end
    BaseLR-->>NumpyFunc: float lr
    NumpyFunc-->>Trainer: lr tensor

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested labels

new feature

Suggested reviewers

  • njzjz
  • wanghan-iapcm
  • iProzd
🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 45.45%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
  • Description Check: ✅ Passed. Check skipped - CodeRabbit’s high-level summary is enabled.
  • Title check: ✅ Passed. The PR title 'refactor: unify learning rate schedulers with array API' accurately reflects the main objective of consolidating learning-rate logic across backends into a unified BaseLR using array_api_compat, which is the primary change.




📜 Recent review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 947e8a6 and ad40b00.

📒 Files selected for processing (3)
  • source/tests/pd/model/test_model.py
  • source/tests/pt/model/test_model.py
  • source/tests/tf/test_lr.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • source/tests/tf/test_lr.py
🧰 Additional context used
🧬 Code graph analysis (2)
source/tests/pt/model/test_model.py (1)
deepmd/tf/utils/learning_rate.py (3)
  • LearningRateSchedule (21-131)
  • build (69-108)
  • start_lr (39-48)
source/tests/pd/model/test_model.py (1)
deepmd/tf/utils/learning_rate.py (3)
  • LearningRateSchedule (21-131)
  • build (69-108)
  • start_lr (39-48)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (40)
  • GitHub Check: Test Python (11, 3.13)
  • GitHub Check: Test Python (12, 3.13)
  • GitHub Check: Test Python (10, 3.13)
  • GitHub Check: Test Python (2, 3.13)
  • GitHub Check: Test Python (12, 3.10)
  • GitHub Check: Test Python (6, 3.13)
  • GitHub Check: Test Python (8, 3.13)
  • GitHub Check: Test Python (11, 3.10)
  • GitHub Check: Test Python (9, 3.13)
  • GitHub Check: Test Python (9, 3.10)
  • GitHub Check: Test Python (5, 3.10)
  • GitHub Check: Test Python (3, 3.13)
  • GitHub Check: Test Python (7, 3.13)
  • GitHub Check: Test Python (8, 3.10)
  • GitHub Check: Test Python (6, 3.10)
  • GitHub Check: Test Python (5, 3.13)
  • GitHub Check: Test Python (10, 3.10)
  • GitHub Check: Test Python (4, 3.10)
  • GitHub Check: Test Python (1, 3.10)
  • GitHub Check: Test Python (7, 3.10)
  • GitHub Check: Test Python (1, 3.13)
  • GitHub Check: Test Python (4, 3.13)
  • GitHub Check: Test Python (2, 3.10)
  • GitHub Check: Test Python (3, 3.10)
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Test C++ (false, true, true, false)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Analyze (python)
  • GitHub Check: Test C++ (true, false, false, true)
  • GitHub Check: Test C++ (false, false, false, true)
  • GitHub Check: Test C++ (true, true, true, false)
🔇 Additional comments (11)
source/tests/pt/model/test_model.py (6)

51-53: LGTM: new schedule import aligns with refactor.


108-112: Good: explicit num_steps for scheduling.


139-141: LGTM: build now uses total steps.


228-235: LGTM: dict-based schedule config is consistent.


244-248: LGTM.


306-311: No action needed. The num_steps parameter is fully supported by LearningRateExp (imported as MyLRExp), which has num_steps: int = 100000 in its constructor signature. The call is valid and will not raise a TypeError.

Likely an incorrect or invalid review comment.

source/tests/pd/model/test_model.py (5)

51-53: LGTM: new schedule import aligns with refactor.


108-112: Good: explicit num_steps for scheduling.


139-141: LGTM: build now uses total steps.


228-235: LGTM: dict-based schedule config is consistent.


244-248: LGTM.




Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@deepmd/dpmodel/utils/learning_rate.py`:
- Around lines 336-338: BaseLR.__init__ currently computes self.lr_min_factor = self.stop_lr / self.start_lr, which raises ZeroDivisionError if start_lr is zero. Add an explicit validation at the start of BaseLR.__init__ that checks start_lr is non-zero (and preferably > 0) and raises a clear ValueError such as "start_lr must be > 0", so the invalid configuration is caught with a descriptive message before any division occurs. Update callers/tests if they relied on a zero value.
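A minimal sketch of the guard being requested, assuming BaseLR.__init__ receives start_lr and stop_lr as plain floats (any field names beyond those mentioned above are illustrative):

class _BaseLRGuardSketch:
    """Illustrative only: shows where the start_lr validation would sit."""

    def __init__(self, start_lr: float, stop_lr: float) -> None:
        if start_lr <= 0:
            raise ValueError(f"start_lr must be > 0, got {start_lr}")
        self.start_lr = start_lr
        self.stop_lr = stop_lr
        # With start_lr validated above, this division can no longer raise ZeroDivisionError.
        self.lr_min_factor = self.stop_lr / self.start_lr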
🧹 Nitpick comments (5)
deepmd/tf/fit/polar.py (1)

859-867: LGTM! Docstring updated to reflect the renamed scheduler class.

The documentation change from LearningRateExp to LearningRateSchedule correctly aligns with the PR's refactoring of the learning rate scheduler API.

Consider adding a type annotation to the function signature for consistency:

-    def get_loss(self, loss: dict, lr) -> Loss:
+    def get_loss(self, loss: dict, lr: "LearningRateSchedule") -> Loss:

This would also apply to PolarFittingSeA.get_loss at line 618, which currently lacks parameter documentation entirely.

source/tests/consistent/test_learning_rate.py (1)

74-78: Redundant skip check inside compare_test_with_warmup_ref.

The skipTest guard on line 75-76 is redundant since all call sites already check if self.warmup_step is not None before invoking this method. Consider removing the internal check or, if keeping it as defensive coding, note that warmup_ref being None while warmup_step is set would indicate a bug in setUp.

♻️ Suggested simplification
     def compare_test_with_warmup_ref(self, step: Array) -> None:
-        if self.warmup_ref is None:
-            self.skipTest("warmup not enabled")
+        assert self.warmup_ref is not None, "warmup_ref should be set when warmup_step is set"
         test = self.lr.value(step)
         np.testing.assert_allclose(self.warmup_ref, to_numpy_array(test), atol=1e-10)
deepmd/pt/utils/utils.py (1)

229-237: Type hint mismatch: int is handled but not declared in signature.

The function signature at line 230 declares torch.Tensor | np.ndarray | float | None, but line 234 also handles int inputs. Additionally, the @overload signatures (lines 221-227) don't cover the new input types (float, int, np.ndarray), which may cause type checker warnings for callers.

Proposed fix to align type hints
 @overload
 def to_numpy_array(xx: torch.Tensor) -> np.ndarray: ...
 
 
 @overload
 def to_numpy_array(xx: None) -> None: ...
 
+
+@overload
+def to_numpy_array(xx: np.ndarray) -> np.ndarray: ...
+
+
+@overload
+def to_numpy_array(xx: float | int) -> np.ndarray: ...
+
 
 def to_numpy_array(
-    xx: torch.Tensor | np.ndarray | float | None,
+    xx: torch.Tensor | np.ndarray | float | int | None,
 ) -> np.ndarray | None:
deepmd/pd/utils/utils.py (1)

250-253: Note: Pre-existing dead code.

Line 251 provides a default value (np.float64) when the key is not found, so prec can never be None at line 252. The check at lines 252-253 is unreachable. This is pre-existing and unrelated to the current changes, but worth noting for future cleanup.

deepmd/dpmodel/utils/learning_rate.py (1)

72-75: Potential type inconsistency when both stop_lr and stop_ratio are None.

If validation in argcheck.py is bypassed (e.g., programmatic instantiation without going through argument normalization), self.stop_lr could remain None, which would cause issues downstream (e.g., line 246 max(self.stop_lr, 1e-10) or line 338 self.stop_lr / self.start_lr).

Consider adding a defensive runtime check here, or ensure documentation clearly states that one of stop_lr or stop_ratio must be provided.

♻️ Suggested defensive check
         # === Step 1. Compute stop_lr from stop_ratio if needed ===
         # Mutual exclusion validated in argcheck.py
         if stop_ratio is not None:
             self.stop_lr = start_lr * stop_ratio
         else:
             self.stop_lr = stop_lr  # type: ignore[assignment]
+
+        if self.stop_lr is None:
+            raise ValueError(
+                "Either stop_lr or stop_ratio must be provided"
+            )
📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2a9667e and 89cdf56.

📒 Files selected for processing (21)
  • deepmd/dpmodel/utils/learning_rate.py
  • deepmd/pd/train/training.py
  • deepmd/pd/utils/utils.py
  • deepmd/pt/train/training.py
  • deepmd/pt/utils/utils.py
  • deepmd/tf/fit/dipole.py
  • deepmd/tf/fit/dos.py
  • deepmd/tf/fit/ener.py
  • deepmd/tf/fit/fitting.py
  • deepmd/tf/fit/polar.py
  • deepmd/tf/train/trainer.py
  • deepmd/tf/utils/__init__.py
  • deepmd/tf/utils/learning_rate.py
  • deepmd/utils/argcheck.py
  • source/tests/consistent/test_learning_rate.py
  • source/tests/pd/model/test_model.py
  • source/tests/pd/test_lr.py
  • source/tests/pt/model/test_model.py
  • source/tests/pt/test_lr.py
  • source/tests/tf/test_lr.py
  • source/tests/universal/dpmodel/utils/test_learning_rate.py
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2024-10-15T22:22:24.889Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4219
File: deepmd/utils/learning_rate.py:48-53
Timestamp: 2024-10-15T22:22:24.889Z
Learning: Methods in `deepmd/utils/learning_rate.py` that return NumPy scalar types should have return type annotations using the corresponding NumPy types, such as `np.float64`.

Applied to files:

  • deepmd/pd/utils/utils.py
  • deepmd/tf/fit/fitting.py
📚 Learning: 2025-12-12T13:40:14.334Z
Learnt from: CR
Repo: deepmodeling/deepmd-kit PR: 0
File: AGENTS.md:0-0
Timestamp: 2025-12-12T13:40:14.334Z
Learning: Run core tests with `pytest source/tests/tf/test_dp_test.py::TestDPTestEner::test_1frame -v` to validate basic functionality

Applied to files:

  • source/tests/tf/test_lr.py
📚 Learning: 2024-10-08T15:32:11.479Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4144
File: source/api_cc/tests/test_deeppot_dpa_pt.cc:166-246
Timestamp: 2024-10-08T15:32:11.479Z
Learning: Refactoring between test classes `TestInferDeepPotDpaPt` and `TestInferDeepPotDpaPtNopbc` is addressed in PR `#3905`.

Applied to files:

  • source/tests/pt/test_lr.py
🧬 Code graph analysis (11)
deepmd/pt/utils/utils.py (1)
deepmd/pt/model/network/network.py (1)
  • Tensor (34-35)
source/tests/tf/test_lr.py (2)
deepmd/dpmodel/utils/learning_rate.py (2)
  • LearningRateExp (170-278)
  • value (123-166)
deepmd/tf/utils/learning_rate.py (5)
  • LearningRateSchedule (20-123)
  • value (102-123)
  • base_lr (50-66)
  • build (68-100)
  • start_lr (38-47)
deepmd/tf/utils/learning_rate.py (1)
deepmd/dpmodel/utils/learning_rate.py (2)
  • BaseLR (25-166)
  • value (123-166)
source/tests/pd/model/test_model.py (1)
deepmd/tf/utils/learning_rate.py (2)
  • LearningRateSchedule (20-123)
  • start_lr (38-47)
deepmd/pt/train/training.py (2)
deepmd/pd/train/training.py (1)
  • step (722-963)
deepmd/tf/utils/learning_rate.py (2)
  • value (102-123)
  • start_lr (38-47)
source/tests/universal/dpmodel/utils/test_learning_rate.py (4)
deepmd/pd/utils/utils.py (3)
  • to_numpy_array (230-230)
  • to_numpy_array (234-234)
  • to_numpy_array (237-256)
deepmd/pt/utils/utils.py (3)
  • to_numpy_array (222-222)
  • to_numpy_array (226-226)
  • to_numpy_array (229-249)
deepmd/dpmodel/utils/learning_rate.py (3)
  • LearningRateCosine (282-374)
  • LearningRateExp (170-278)
  • value (123-166)
deepmd/tf/utils/learning_rate.py (2)
  • start_lr (38-47)
  • value (102-123)
deepmd/tf/train/trainer.py (1)
deepmd/tf/utils/learning_rate.py (1)
  • LearningRateSchedule (20-123)
deepmd/dpmodel/utils/learning_rate.py (2)
deepmd/tf/utils/learning_rate.py (1)
  • start_lr (38-47)
deepmd/utils/argcheck.py (1)
  • register (185-215)
deepmd/tf/utils/__init__.py (1)
deepmd/tf/utils/learning_rate.py (1)
  • LearningRateSchedule (20-123)
source/tests/pt/model/test_model.py (1)
deepmd/tf/utils/learning_rate.py (2)
  • LearningRateSchedule (20-123)
  • start_lr (38-47)
source/tests/pd/test_lr.py (2)
deepmd/tf/utils/learning_rate.py (4)
  • LearningRateSchedule (20-123)
  • start_lr (38-47)
  • base_lr (50-66)
  • build (68-100)
deepmd/dpmodel/utils/learning_rate.py (1)
  • LearningRateExp (170-278)
🪛 Ruff (0.14.11)
deepmd/tf/utils/learning_rate.py

34-34: Avoid specifying long messages outside the exception class

(TRY003)


65-65: Avoid specifying long messages outside the exception class

(TRY003)


122-122: Avoid specifying long messages outside the exception class

(TRY003)

deepmd/dpmodel/utils/learning_rate.py

40-40: Unused method argument: kwargs

(ARG002)


86-86: Avoid specifying long messages outside the exception class

(TRY003)


88-88: Avoid specifying long messages outside the exception class

(TRY003)


90-90: Avoid specifying long messages outside the exception class

(TRY003)


241-243: Avoid specifying long messages outside the exception class

(TRY003)

deepmd/utils/argcheck.py

2506-2509: Avoid specifying long messages outside the exception class

(TRY003)


2511-2514: Avoid specifying long messages outside the exception class

(TRY003)


2542-2545: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (45)
source/tests/consistent/test_learning_rate.py (2)

56-68: LGTM!

The test setup correctly initializes reference values for both regular steps and warmup steps, enabling cross-backend consistency validation.


83-104: LGTM!

The test methods consistently validate learning rate values across PyTorch, array_api_strict, and JAX backends, with proper conditional warmup testing.

deepmd/tf/fit/ener.py (1)

852-866: LGTM!

The docstring type annotation update to LearningRateSchedule aligns with the broader refactoring in this PR.

deepmd/tf/fit/fitting.py (1)

68-83: LGTM!

The docstring type annotation update in the abstract get_loss method correctly reflects the new LearningRateSchedule type used throughout the codebase.

deepmd/tf/fit/dipole.py (1)

384-398: LGTM!

The docstring type annotation update to LearningRateSchedule is consistent with the base class and other fitting implementations.

deepmd/tf/fit/dos.py (1)

651-668: LGTM!

The docstring type annotation update to LearningRateSchedule is consistent with the base class and other fitting implementations.

source/tests/pt/model/test_model.py (2)

51-53: LGTM!

The import change from LearningRateExp to LearningRateSchedule correctly aligns with the refactored TensorFlow learning rate API. The test appropriately uses separate LR implementations for each backend being compared.


228-236: LGTM!

The migration to dict-based LearningRateSchedule construction is correct and provides all required parameters. This aligns with the unified API that accepts a configuration dictionary.

deepmd/pd/utils/utils.py (1)

242-246: LGTM!

The added handling for scalar inputs and numpy arrays mirrors the changes in the PT backend, maintaining consistency across the codebase.

deepmd/tf/utils/__init__.py (2)

20-27: LGTM!

The __all__ list correctly reflects the updated export, maintaining consistency between the module's imports and its public interface.


9-11: This concern is not applicable. The search shows no code imports LearningRateExp from deepmd.tf.utils, and the test file explicitly imports LearningRateExp from deepmd.dpmodel.utils.learning_rate. The LearningRateSchedule exported from deepmd.tf.utils is the TensorFlow-specific wrapper class, not a renamed version of a previously exported symbol. This is the correct architecture: core learning rate algorithms (LearningRateExp) reside in dpmodel, while framework-specific wrappers (LearningRateSchedule) reside in tf/pt/pd submodules.

Likely an incorrect or invalid review comment.

source/tests/pd/model/test_model.py (2)

51-53: LGTM!

The import correctly uses LearningRateSchedule for the TensorFlow side of the consistency test, aligning with the refactored API.


228-236: LGTM!

The dict-based LearningRateSchedule construction is correct and consistent with the PT test file, providing all required configuration parameters for the exponential decay schedule.

source/tests/tf/test_lr.py (3)

1-21: LGTM! Well-structured test file with proper imports and clear documentation.

The test module correctly separates validation tests from build/integration tests, and the docstring clearly states that core algorithm tests are in dpmodel tests.


23-44: LGTM! Validation tests properly verify error handling.

The tests correctly validate that ValueError is raised for missing start_lr and RuntimeError is raised when accessing value() or base_lr before build() is called.


47-110: LGTM! Build and integration tests are comprehensive.

The tests properly validate:

  • Tensor output type and dtype
  • Default scheduler type inference
  • Value consistency between TF tensor and BaseLR.value()
  • Accessor methods work correctly before and after build
deepmd/pt/train/training.py (3)

275-278: LGTM! Correct stop_steps assignment.

Setting stop_steps to self.num_steps aligns with the unified BaseLR interface where warmup is now handled internally.


698-702: LGTM! Lambda LR calculation is correct.

The lambda correctly computes the learning rate multiplier by dividing the absolute LR value by start_lr. The step + self.start_step offset properly handles training resumption.
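For concreteness, a self-contained sketch of this LambdaLR pattern; the _FakeLR object below is a stand-in for the unified LR class, not deepmd code:

import torch


class _FakeLR:
    """Stand-in exposing the start_lr attribute and value(step) used by the lambda."""

    start_lr = 1e-3

    def value(self, step: int) -> float:
        return self.start_lr * 0.95 ** (step // 100)


lr_exp, start_step = _FakeLR(), 0
param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=lr_exp.start_lr)
# LambdaLR expects a multiplicative factor, hence the division by start_lr.
scheduler = torch.optim.lr_scheduler.LambdaLR(
    opt, lr_lambda=lambda step: lr_exp.value(step + start_step) / lr_exp.start_lr
)
for _ in range(3):
    opt.step()
    scheduler.step()
print(scheduler.get_last_lr())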


797-802: LGTM! Simplified pref_lr assignment.

With warmup logic now handled internally by BaseLR, the training loop correctly uses the current learning rate directly without conditional warmup branching.

source/tests/pt/test_lr.py (4)

22-24: LGTM! Corrected decay_steps range.

The constraint decay_steps ∈ [400, 500] with stop_steps ∈ [500, 1500] ensures decay_steps never exceeds stop_steps, which would otherwise raise a ValueError in LearningRateExp.


35-42: LGTM! Correct migration to dictionary-based construction.

The LearningRateSchedule constructor now uses a dictionary with explicit type specification, aligning with the unified API.


77-83: LGTM! Good refactor to use local variable.

Using decay_step_for_rate instead of mutating self.decay_step prevents side effects that could affect subsequent test iterations.


121-139: LGTM! Cosine scheduler test correctly migrated.

The test validates basic cosine decay behavior including start, end, and midpoint values with the new keyword-argument API.

deepmd/tf/train/trainer.py (2)

106-118: LGTM! Clean refactor to use LearningRateSchedule.

The helper function correctly:

  1. Extracts and handles scale_by_worker separately
  2. Passes filtered params to LearningRateSchedule
  3. Returns properly typed tuple

429-434: LGTM! Simplified logging message.

The log message now focuses on the essential learning rate values (start, current, final) without exposing internal decay parameters, which aligns with the abstracted scheduler interface.

deepmd/pd/train/training.py (3)

241-244: LGTM! Consistent with PyTorch implementation.

The get_lr function follows the same pattern as the PyTorch trainer, setting stop_steps to total steps and using BaseLR directly.


582-586: LGTM! Paddle LambdaDecay correctly configured.

The lambda function matches the PyTorch pattern, computing the LR multiplier as lr_exp.value(step + start_step) / lr_exp.start_lr for proper LR scheduling with resume support.


747-749: LGTM! Simplified pref_lr assignment consistent with PyTorch.

The warmup-free pref_lr = cur_lr assignment aligns with the unified approach where warmup is handled internally by BaseLR.

deepmd/tf/utils/learning_rate.py (2)

30-36: LGTM on constructor design.

The constructor correctly stores the configuration dictionary and validates the required start_lr parameter. The lazy initialization of _base_lr is appropriate since stop_steps is only known at build time.


68-100: Consider potential performance implications of tf.numpy_function.

Using tf.numpy_function wraps Python/NumPy execution, which works correctly but has implications:

  1. It breaks TF graph optimization and XLA compilation
  2. The function runs on CPU even in GPU-enabled sessions

This is acceptable for learning rate computation (low-frequency, scalar operation), but worth noting in documentation if performance-critical use cases arise.
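As a rough illustration of this wiring (not the actual LearningRateSchedule.build implementation; the helper name and exact dtype handling are assumptions):

import numpy as np
import tensorflow as tf


def build_lr_tensor(lr_value_fn, global_step):
    """Sketch: wrap a Python LR callable with tf.numpy_function and restore shape/dtype."""

    def _eval(step):
        # Runs as a Python/NumPy callback on CPU; fine for a scalar, low-frequency value.
        return np.float64(lr_value_fn(int(step)))

    lr = tf.numpy_function(_eval, [global_step], tf.float64)
    lr.set_shape([])  # numpy_function drops static shape information
    return tf.cast(lr, tf.float32)


global_step = tf.constant(500, dtype=tf.int64)
lr_t = build_lr_tensor(lambda s: 1e-3 * 0.95 ** (s // 100), global_step)
print(float(lr_t))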

source/tests/pd/test_lr.py (3)

21-23: Good adjustment to test parameter ranges.

The comment correctly notes that decay_steps must not exceed stop_steps. The range np.arange(400, 501, 100) produces [400, 500], which correctly stays within the minimum stop_steps of 500.


34-41: Correct dict-based construction for LearningRateSchedule.

The test properly uses the new dictionary-based API for LearningRateSchedule, matching the expected interface where stop_steps is provided during build().


76-83: Good refactor to use local variable.

Using decay_step_for_rate instead of modifying self.decay_step avoids unintended state mutation between test iterations. This is a clean improvement.

source/tests/universal/dpmodel/utils/test_learning_rate.py (4)

18-27: LGTM on basic decay test.

The test correctly verifies:

  • Initial LR at step 0 equals start_lr (1e-3)
  • Final LR at step 10000 equals stop_lr (1e-5)

The tolerances (rtol=1e-10 and rtol=1e-5) are appropriate for the precision requirements.


80-93: Verify warmup boundary behavior at exactly warmup_steps.

The test at line 92 checks lr.value(1000) equals 1e-3, but step 1000 is exactly at warmup_steps=1000. Based on the implementation in BaseLR.value() (line 162 in learning_rate.py), the condition is step < self.warmup_steps, so step 1000 would enter the decay phase, not warmup.

At step 1000 (decay_step=0), the decay phase returns start_lr since decay_rate^0 = 1. So the assertion is correct, but the test comment could clarify this boundary behavior.
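Spelled out with the stepwise exponential decay in question, the boundary value is:

$lr(1000) = \mathrm{start\_lr} \cdot \mathrm{decay\_rate}^{\lfloor (1000 - \mathrm{warmup\_steps}) / \mathrm{decay\_steps} \rfloor} = \mathrm{start\_lr} \cdot \mathrm{decay\_rate}^{0} = \mathrm{start\_lr} = 10^{-3}$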


147-163: Good coverage for array/vectorized inputs.

Testing array inputs ensures JIT compatibility across backends. The assertions correctly verify:

  • Shape preservation (lrs.shape == (5,))
  • Warmup at step 0 returns warmup_start_lr (0.0)
  • End of warmup at step 1000 returns start_lr (1e-3)

207-217: Good validation test for decay_steps constraint.

The test correctly verifies that decay_steps > stop_steps raises ValueError with appropriate message content checks.

deepmd/utils/argcheck.py (4)

2483-2515: Correct mutual exclusion validation for stop_lr/stop_ratio.

The validation correctly:

  1. Checks if both are provided (error)
  2. Checks if neither is provided (error)
  3. Returns True on success

The error messages are clear and informative, which outweighs the TRY003 style concern.


2536-2546: Correct warmup validation logic.

The check correctly identifies non-zero warmup_steps (since default is 0) and non-None warmup_ratio as the conditions for mutual exclusion. This aligns with the behavior where warmup_steps=0 means "no warmup" rather than "explicit zero warmup steps."


2549-2619: Well-designed common argument builder.

The _learning_rate_common_args function effectively reduces code duplication by:

  1. Building common args (start_lr, stop_lr, stop_ratio)
  2. Accepting extra_args for scheduler-specific parameters
  3. Appending warmup-related args at the end

This design allows learning_rate_exp to inject decay_steps and decay_rate while sharing the rest.


2690-2714: Good integration of validation checks.

The nested _check_lr_args function composes both validation checks and is correctly passed to extra_check. This ensures validation runs during argument normalization.

deepmd/dpmodel/utils/learning_rate.py (4)

31-99: Solid base class design with warmup support.

The BaseLR.__init__ correctly:

  1. Computes stop_lr from stop_ratio if provided
  2. Computes warmup_steps from warmup_ratio if provided
  3. Validates step ranges at runtime
  4. Computes derived values (warmup_start_lr, decay_stop_steps)

Note: The kwargs parameter (flagged by static analysis) is necessary for the plugin system's **kwargs forwarding pattern.


137-166: Correct warmup and decay phase implementation.

The value() method correctly:

  1. Handles scalar vs array inputs
  2. Uses array API for backend-agnostic operations
  3. Computes linear warmup interpolation
  4. Delegates decay computation to subclass _decay_value()
  5. Uses xp.where for JIT-compatible branching

The use of xp.maximum(..., 0.0) for decay_step ensures negative values don't propagate during warmup phase.
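A minimal, self-contained illustration of the xp.where/xp.maximum pattern described above, written directly against array_api_compat (the parameter names and the simplified smooth decay curve are assumptions, not the library's code):

import array_api_compat
import numpy as np


def warmup_then_decay(step, start_lr=1e-3, stop_lr=1e-5, warmup_steps=1000, num_steps=10000):
    """Sketch: linear warmup then an exponential-style decay, branch-free via xp.where."""
    step = np.asarray(step, dtype=np.float64)
    xp = array_api_compat.array_namespace(step)
    warmup_lr = start_lr * step / max(warmup_steps, 1)
    # Clamp so steps inside the warmup window do not produce a negative decay step.
    decay_step = xp.maximum(step - warmup_steps, 0.0)
    decay_num_steps = num_steps - warmup_steps
    decay_lr = start_lr * (stop_lr / start_lr) ** (decay_step / decay_num_steps)
    return xp.where(step < warmup_steps, warmup_lr, decay_lr)


print(warmup_then_decay(np.array([0.0, 500.0, 1000.0, 10000.0])))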


240-250: Good validation and numerical stability handling.

The implementation correctly:

  1. Validates decay_steps <= decay_total to prevent invalid configurations
  2. Clamps stop_lr to 1e-10 for log computation to avoid log(0)
  3. Stores original stop_lr as min_lr for clamping

This ensures numerical stability while preserving the intended final learning rate behavior.
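In formula form, the relationship implied here is (a reading of the description above, with decay_total denoting the number of post-warmup steps):

$\mathrm{decay\_rate} = \exp\!\left(\frac{\ln\!\big(\max(\mathrm{stop\_lr},\,10^{-10}) / \mathrm{start\_lr}\big)}{\mathrm{decay\_total} / \mathrm{decay\_steps}}\right)$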


364-373: Correct cosine decay with boundary handling.

The cosine annealing formula is correctly implemented:

  • Uses decay_stop_steps instead of stop_steps to account for warmup
  • Clamps to min_lr for steps beyond the decay phase using xp.where

This ensures the learning rate doesn't oscillate beyond the intended schedule.
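For reference, standard cosine annealing with a floor at min_lr, which is what the description above points to (the exact clamping details are an assumption):

$lr(t) = \mathrm{min\_lr} + \tfrac{1}{2}\left(\mathrm{start\_lr} - \mathrm{min\_lr}\right)\left(1 + \cos\frac{\pi t}{\mathrm{decay\_stop\_steps}}\right) \quad \text{for } 0 \le t \le \mathrm{decay\_stop\_steps}, \qquad lr(t) = \mathrm{min\_lr} \text{ otherwise}$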


Contributor

Copilot AI left a comment

Pull request overview

This pull request refactors the learning rate schedulers to use array_api_compat for backend-agnostic implementation, consolidating logic from TensorFlow, PyTorch, and PaddlePaddle backends into a unified dpmodel layer. The refactoring adds warmup support and flexible configuration options while ensuring JIT compatibility across backends.

Changes:

  • Unified learning rate scheduler implementation in dpmodel using array API operations
  • Added warmup functionality (warmup_steps, warmup_ratio, warmup_start_factor) with mutual exclusion validation
  • Added stop_ratio parameter as alternative to stop_lr with mutual exclusion validation
  • Updated all backend implementations to use the unified BaseLR, removing duplicated warmup logic
  • Added comprehensive tests for core functionality and cross-backend consistency

Reviewed changes

Copilot reviewed 21 out of 21 changed files in this pull request and generated 6 comments.

Show a summary per file
  • deepmd/dpmodel/utils/learning_rate.py: Core unified learning rate implementation with warmup and array API support
  • deepmd/utils/argcheck.py: Updated argument definitions with warmup parameters and mutual exclusion validation
  • deepmd/tf/utils/learning_rate.py: Simplified TensorFlow wrapper to use unified BaseLR
  • deepmd/tf/train/trainer.py: Updated to use new LearningRateSchedule API
  • deepmd/tf/fit/*.py: Updated type annotations for learning rate parameter
  • deepmd/pt/train/training.py: Removed local warmup implementation, delegated to BaseLR
  • deepmd/pt/utils/utils.py: Extended to_numpy_array to handle scalars and numpy arrays
  • deepmd/pd/train/training.py: Removed local warmup implementation, delegated to BaseLR
  • deepmd/pd/utils/utils.py: Extended to_numpy_array to handle scalars and numpy arrays
  • source/tests/universal/dpmodel/utils/test_learning_rate.py: New comprehensive tests for core learning rate functionality
  • source/tests/tf/test_lr.py: New tests for TensorFlow wrapper
  • source/tests/pt/test_lr.py: Updated tests to use keyword arguments
  • source/tests/pd/test_lr.py: Updated tests to use keyword arguments
  • source/tests/consistent/test_learning_rate.py: Enhanced consistency tests with warmup validation


@OutisLi
Collaborator Author

OutisLi commented Jan 15, 2026

TODO:

  • Adjust input.json in example files
  • Change docs

@OutisLi OutisLi added the breaking change label (Breaking changes that should notify users.) Jan 15, 2026
- Refactor BaseLR in dpmodel to use array_api_compat for backend-agnostic implementation
- Consolidate learning rate logic from TF/PT/PD backends into unified dpmodel layer
- Use array API operations (xp.where, xp.clip, etc.) for JIT compatibility across backends
- Add warmup support (warmup_steps, warmup_ratio, warmup_start_factor) during refactoring
- Add stop_ratio parameter as alternative to stop_lr for flexible configuration
- Implement mutual exclusion validation for stop_lr/stop_ratio and warmup_steps/warmup_ratio
- Update all backends to use unified BaseLR implementation
- Add comprehensive consistency tests across NumPy/PyTorch/JAX/array_api_strict backends
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🤖 Fix all issues with AI agents
In `@deepmd/dpmodel/utils/learning_rate.py`:
- Around lines 194-277: The constructor of the exponential scheduler (LearningRateExp.__init__) currently allows decay_steps <= 0, which leads to division by zero. Add an early validation after reading decay_steps (and before using it in the decay-rate computation and in _decay_value) that raises ValueError with a clear message such as "decay_steps must be > 0". Perform this check before comparing to decay_total and before any math that divides by decay_steps, so that both the np.log(...) / (decay_total / self.decay_steps) computation and the floor division in _decay_value are protected.
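A minimal sketch of the requested guard (the free-standing function shape is hypothetical; in the real code this check would sit inside LearningRateExp.__init__):

def _validate_decay_steps(decay_steps: int, decay_total: int) -> None:
    """Sketch: protect both the decay-rate computation and the floor division in _decay_value."""
    if decay_steps <= 0:
        raise ValueError(f"decay_steps must be > 0, got {decay_steps}")
    if decay_steps > decay_total:
        raise ValueError(
            f"decay_steps ({decay_steps}) must not exceed the total decay steps ({decay_total})"
        )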

In `@deepmd/tf/train/trainer.py`:
- Around lines 106-118: When building the learning-rate schedule in DPTrainer.build (and in the same logic around lines 242-246), guard the stop_batch == 0 case and avoid calling LearningRateSchedule.build, which would pass num_steps=0 into BaseLR. Instead, short-circuit and return a constant LR tensor or a trivial LearningRateSchedule that does not call into BaseLR. Concretely: detect stop_batch == 0 before calling LearningRateSchedule.build(self.global_step, self.stop_batch), and in that branch create a fixed scalar/constant schedule (or skip build and set the LR to a constant tensor) so that BaseLR's num_steps > 0 validation is not triggered.
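A rough sketch of that short-circuit as a free function (hypothetical helper; the actual fix would live in DPTrainer.build):

import tensorflow as tf


def build_learning_rate(lr_schedule, global_step, stop_batch):
    """Sketch: skip LearningRateSchedule.build() when there are no training steps."""
    if stop_batch == 0:
        # No optimizer updates will run, so a constant start LR is sufficient
        # and BaseLR's num_steps > 0 validation is never triggered.
        return tf.constant(lr_schedule.start_lr(), dtype=tf.float32)
    return lr_schedule.build(global_step, stop_batch)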
🧹 Nitpick comments (5)
deepmd/pt/utils/utils.py (1)

221-238: Overload signatures are incomplete for new input types.

The function signature now accepts torch.Tensor | np.ndarray | float | None, but the @overload declarations only cover torch.Tensor and None. This may cause type checkers to flag valid calls with np.ndarray or float arguments as errors.

♻️ Proposed fix to add missing overloads
 @overload
 def to_numpy_array(xx: torch.Tensor) -> np.ndarray: ...


 @overload
 def to_numpy_array(xx: None) -> None: ...


+@overload
+def to_numpy_array(xx: np.ndarray) -> np.ndarray: ...
+
+
+@overload
+def to_numpy_array(xx: float) -> np.ndarray: ...
+
+
 def to_numpy_array(
     xx: torch.Tensor | np.ndarray | float | None,
 ) -> np.ndarray | None:
deepmd/pd/train/training.py (1)

241-245: Avoid mutating lr_params in-place when injecting num_steps.

This currently mutates the caller’s config dict (and, for multi-task, potentially multiple shared dicts). Prefer copying before modification.

Proposed fix
         def get_lr(lr_params: dict[str, Any]) -> BaseLR:
-            lr_params["num_steps"] = self.num_steps
-            lr_schedule = BaseLR(**lr_params)
+            _params = dict(lr_params)
+            _params["num_steps"] = self.num_steps
+            lr_schedule = BaseLR(**_params)
             return lr_schedule
source/tests/pd/model/test_model.py (1)

306-306: Consider using keyword arguments for consistency.

This instantiation uses positional arguments while the codebase (including _get_dp_lr at line 229-236 and the PT test file) uses keyword arguments. Consider using keyword arguments for clarity and consistency with the new API style.

♻️ Suggested change
-        my_lr = MyLRExp(self.start_lr, self.stop_lr, self.decay_steps, self.num_steps)
+        my_lr = MyLRExp(
+            start_lr=self.start_lr,
+            stop_lr=self.stop_lr,
+            decay_steps=self.decay_steps,
+            num_steps=self.num_steps,
+        )
source/tests/pt/model/test_model.py (1)

306-306: Consider using keyword arguments for consistency.

This instantiation uses positional arguments while the codebase (including _get_dp_lr at lines 229-236 and source/tests/pt/test_lr.py) uses keyword arguments. Consider using keyword arguments for clarity and consistency with the new API style.

♻️ Suggested change
-        my_lr = MyLRExp(self.start_lr, self.stop_lr, self.decay_steps, self.num_steps)
+        my_lr = MyLRExp(
+            start_lr=self.start_lr,
+            stop_lr=self.stop_lr,
+            decay_steps=self.decay_steps,
+            num_steps=self.num_steps,
+        )
source/tests/universal/dpmodel/utils/test_learning_rate.py (1)

204-240: Add tests for mutual exclusion parameter validation.

The implementation validates that stop_lr and stop_lr_rate are mutually exclusive, and that warmup_steps and warmup_ratio are mutually exclusive. Consider adding test cases to verify these validations raise appropriate errors:

  • Both stop_lr and stop_lr_rate provided
  • Both warmup_steps and warmup_ratio provided
  • Neither parameter provided in each pair (if applicable)
📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 89cdf56 and 84f3953.

📒 Files selected for processing (21)
  • deepmd/dpmodel/utils/learning_rate.py
  • deepmd/pd/train/training.py
  • deepmd/pd/utils/utils.py
  • deepmd/pt/train/training.py
  • deepmd/pt/utils/utils.py
  • deepmd/tf/fit/dipole.py
  • deepmd/tf/fit/dos.py
  • deepmd/tf/fit/ener.py
  • deepmd/tf/fit/fitting.py
  • deepmd/tf/fit/polar.py
  • deepmd/tf/train/trainer.py
  • deepmd/tf/utils/__init__.py
  • deepmd/tf/utils/learning_rate.py
  • deepmd/utils/argcheck.py
  • source/tests/consistent/test_learning_rate.py
  • source/tests/pd/model/test_model.py
  • source/tests/pd/test_lr.py
  • source/tests/pt/model/test_model.py
  • source/tests/pt/test_lr.py
  • source/tests/tf/test_lr.py
  • source/tests/universal/dpmodel/utils/test_learning_rate.py
🚧 Files skipped from review as they are similar to previous changes (8)
  • deepmd/tf/fit/fitting.py
  • source/tests/consistent/test_learning_rate.py
  • deepmd/tf/fit/dipole.py
  • source/tests/tf/test_lr.py
  • deepmd/tf/fit/ener.py
  • deepmd/tf/fit/polar.py
  • deepmd/pd/utils/utils.py
  • source/tests/pd/test_lr.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2024-10-08T15:32:11.479Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4144
File: source/api_cc/tests/test_deeppot_dpa_pt.cc:166-246
Timestamp: 2024-10-08T15:32:11.479Z
Learning: Refactoring between test classes `TestInferDeepPotDpaPt` and `TestInferDeepPotDpaPtNopbc` is addressed in PR `#3905`.

Applied to files:

  • source/tests/pt/test_lr.py
🧬 Code graph analysis (11)
deepmd/pt/utils/utils.py (1)
deepmd/pt/model/network/network.py (1)
  • Tensor (34-35)
deepmd/tf/train/trainer.py (1)
deepmd/tf/utils/learning_rate.py (1)
  • LearningRateSchedule (20-123)
deepmd/tf/utils/__init__.py (1)
deepmd/tf/utils/learning_rate.py (1)
  • LearningRateSchedule (20-123)
deepmd/tf/utils/learning_rate.py (1)
deepmd/dpmodel/utils/learning_rate.py (2)
  • BaseLR (25-189)
  • value (146-189)
source/tests/pd/model/test_model.py (1)
deepmd/tf/utils/learning_rate.py (3)
  • LearningRateSchedule (20-123)
  • build (68-100)
  • start_lr (38-47)
deepmd/dpmodel/utils/learning_rate.py (3)
deepmd/tf/utils/learning_rate.py (2)
  • start_lr (38-47)
  • value (102-123)
deepmd/pd/train/training.py (1)
  • step (722-963)
deepmd/pt/train/training.py (1)
  • step (773-1142)
source/tests/pt/test_lr.py (2)
deepmd/tf/utils/learning_rate.py (5)
  • LearningRateSchedule (20-123)
  • start_lr (38-47)
  • base_lr (50-66)
  • build (68-100)
  • value (102-123)
deepmd/dpmodel/utils/learning_rate.py (2)
  • LearningRateExp (193-301)
  • value (146-189)
deepmd/pd/train/training.py (3)
deepmd/pt/train/training.py (1)
  • step (773-1142)
deepmd/dpmodel/utils/learning_rate.py (1)
  • value (146-189)
deepmd/tf/utils/learning_rate.py (2)
  • value (102-123)
  • start_lr (38-47)
source/tests/universal/dpmodel/utils/test_learning_rate.py (1)
deepmd/dpmodel/utils/learning_rate.py (3)
  • LearningRateCosine (305-397)
  • LearningRateExp (193-301)
  • value (146-189)
deepmd/utils/argcheck.py (1)
deepmd/utils/plugin.py (2)
  • register (41-59)
  • register (122-141)
source/tests/pt/model/test_model.py (1)
deepmd/tf/utils/learning_rate.py (3)
  • LearningRateSchedule (20-123)
  • build (68-100)
  • start_lr (38-47)
🪛 Ruff (0.14.11)
deepmd/tf/utils/learning_rate.py

34-34: Avoid specifying long messages outside the exception class

(TRY003)


65-65: Avoid specifying long messages outside the exception class

(TRY003)


122-122: Avoid specifying long messages outside the exception class

(TRY003)

deepmd/dpmodel/utils/learning_rate.py

40-40: Unused method argument: kwargs

(ARG002)


75-78: Avoid specifying long messages outside the exception class

(TRY003)


80-83: Avoid specifying long messages outside the exception class

(TRY003)


96-99: Avoid specifying long messages outside the exception class

(TRY003)


109-109: Avoid specifying long messages outside the exception class

(TRY003)


111-111: Avoid specifying long messages outside the exception class

(TRY003)


113-113: Avoid specifying long messages outside the exception class

(TRY003)


264-266: Avoid specifying long messages outside the exception class

(TRY003)

deepmd/utils/argcheck.py

2506-2509: Avoid specifying long messages outside the exception class

(TRY003)


2511-2514: Avoid specifying long messages outside the exception class

(TRY003)


2542-2545: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (31)
  • GitHub Check: Test Python (4, 3.10)
  • GitHub Check: Test Python (9, 3.13)
  • GitHub Check: Test Python (10, 3.10)
  • GitHub Check: Test Python (6, 3.13)
  • GitHub Check: Test Python (12, 3.13)
  • GitHub Check: Test Python (9, 3.10)
  • GitHub Check: Test Python (11, 3.10)
  • GitHub Check: Test Python (10, 3.13)
  • GitHub Check: Test Python (8, 3.10)
  • GitHub Check: Test Python (11, 3.13)
  • GitHub Check: Test Python (2, 3.10)
  • GitHub Check: Test Python (12, 3.10)
  • GitHub Check: Test Python (5, 3.13)
  • GitHub Check: Test Python (8, 3.13)
  • GitHub Check: Test Python (7, 3.10)
  • GitHub Check: Test Python (7, 3.13)
  • GitHub Check: Test Python (1, 3.10)
  • GitHub Check: Test Python (4, 3.13)
  • GitHub Check: Test Python (5, 3.10)
  • GitHub Check: Test Python (3, 3.10)
  • GitHub Check: Test Python (3, 3.13)
  • GitHub Check: Test Python (6, 3.10)
  • GitHub Check: Test Python (2, 3.13)
  • GitHub Check: Test Python (1, 3.13)
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Test C++ (false, true, true, false)
  • GitHub Check: Test C++ (false, false, false, true)
  • GitHub Check: Test C++ (true, true, true, false)
🔇 Additional comments (36)
deepmd/tf/utils/__init__.py (1)

9-11: LGTM!

The public export correctly updates from LearningRateExp to LearningRateSchedule, aligning with the unified LR scheduler refactoring. This is a breaking change as noted in the PR labels.

Also applies to: 20-27

deepmd/pt/utils/utils.py (1)

234-238: LGTM!

The new input handling logic correctly short-circuits for scalars and numpy arrays. The assert at line 238 is reachable and serves as a guard to ensure only torch.Tensor inputs proceed to the conversion logic.

deepmd/pt/train/training.py (3)

275-278: LGTM!

The get_lr function correctly injects num_steps into the learning rate parameters before constructing the BaseLR instance, aligning with the new unified LR scheduler API.


698-702: LGTM!

The LambdaLR lambda correctly computes the multiplicative factor by dividing the absolute learning rate value by start_lr. This properly delegates warmup and decay logic to the BaseLR.value() method while maintaining compatibility with PyTorch's LambdaLR scheduler.

Also applies to: 722-726


797-799: LGTM!

The pref_lr simplification is correct since warmup handling is now encapsulated within the BaseLR.value() method. The scheduler's get_last_lr() already returns the warmup-adjusted learning rate.

deepmd/tf/utils/learning_rate.py (3)

30-36: LGTM!

The constructor correctly validates the presence of start_lr and stores the configuration for deferred schedule building.


68-100: LGTM!

The build() method correctly:

  1. Creates a parameter copy to avoid mutation
  2. Defaults to "exp" type for backward compatibility
  3. Uses tf.numpy_function (TF 2.x API) for runtime LR evaluation
  4. Preserves tensor shape information and casts to float32 for model compatibility

102-123: LGTM!

The value() method correctly guards against unbuilt state and properly delegates to the BaseLR.value() method with appropriate type conversion.

deepmd/tf/fit/dos.py (1)

651-668: LGTM!

The docstring correctly updates the lr parameter type from LearningRateExp to LearningRateSchedule, aligning with the renamed class. The method implementation using lr.start_lr() is compatible with the new API.

deepmd/tf/train/trainer.py (1)

429-434: Verify LR step indexing convention (value(stop_batch) vs value(stop_batch - 1)).

The log prints self.lr.value(stop_batch) as “final lr”. Depending on whether “num_steps” is treated as a count or a last-step index, the LR actually used on the last optimizer update is often value(stop_batch - 1). Please confirm this matches the new BaseLR convention and cross-backend tests.

deepmd/utils/argcheck.py (1)

2483-2547: Nice consolidation + constraints; ensure examples/docs are updated for required stop settings.

Requiring exactly one of stop_lr / stop_lr_rate and the warmup mutual exclusion makes sense, but it’s a user-facing breaking change (configs that omitted both will now hard-fail). Please ensure the example input.json files and docs are updated accordingly (matches the PR TODOs).

Optional: address Ruff TRY003 if CI enforces it
-        raise ValueError(
-            "stop_lr and stop_lr_rate are mutually exclusive. "
-            f"Got stop_lr={data['stop_lr']}, stop_lr_rate={data['stop_lr_rate']}"
-        )
+        raise ValueError(
+            f"stop_lr and stop_lr_rate are mutually exclusive: stop_lr={data['stop_lr']}, "
+            f"stop_lr_rate={data['stop_lr_rate']}"
+        )  # noqa: TRY003

Also applies to: 2549-2672, 2689-2715

deepmd/dpmodel/utils/learning_rate.py (1)

146-190: Request verification: “num_steps” and “final LR” off-by-one semantics across backends.

With decay_num_steps = num_steps - warmup_steps and cosine using step / decay_num_steps, stop_lr is reached at step == num_steps (after warmup). Many training loops run steps [0, num_steps-1], so they may never hit the exact stop LR. Please confirm this is intentional and consistent across TF/PT/PD integrations + the new consistency tests.

Also (optional), consider keeping step integral through exp’s step // decay_steps to avoid accidental non-integer step inputs being silently accepted.

Optional refactor: integer exponent in exp decay
     def _decay_value(self, step: int | Array) -> Array:
@@
-        step_lr = self.start_lr * xp.pow(
+        step_i = xp.astype(step, xp.int64)
+        exp_i = xp.astype(step_i // self.decay_steps, xp.float64)
+        step_lr = self.start_lr * xp.pow(
             xp.asarray(self.decay_rate, device=array_api_compat.device(step)),
-            xp.astype(step // self.decay_steps, xp.float64),
+            exp_i,
         )

Also applies to: 278-301, 363-397

deepmd/pd/train/training.py (1)

582-586: Verify: Adam scheduler path behavior for multi-task + learning_rate_dict, and LR end-step convention.

LambdaDecay references self.lr_exp.start_lr; if self.lr_exp is a dict (multi-task + learning_rate_dict), this will break unless that combo is intentionally unsupported—in which case, please fail fast with a clear error earlier.

Also please confirm the intended “final step” convention for value(step + start_step) vs BaseLR’s num_steps definition (to ensure stop LR is reached when expected).

source/tests/pd/model/test_model.py (5)

25-25: LGTM!

Import correctly updated to use LearningRateExp from deepmd.dpmodel.utils.learning_rate aliased as MyLRExp for the Paddle test.


51-53: LGTM!

Import updated to use the new LearningRateSchedule wrapper for TensorFlow.


111-111: LGTM!

Variable renamed from stop_steps to num_steps to align with the unified API.


140-140: LGTM!

Correctly passes self.num_steps to dp_lr.build() method.


228-236: LGTM!

The _get_dp_lr method correctly returns a LearningRateSchedule with a dict payload containing the required configuration parameters.

source/tests/pt/model/test_model.py (5)

34-34: LGTM!

Import correctly updated to use LearningRateExp from deepmd.pt.utils.learning_rate aliased as MyLRExp for the PyTorch test.


51-53: LGTM!

Import updated to use the new LearningRateSchedule wrapper for TensorFlow.


111-111: LGTM!

Variable renamed from stop_steps to num_steps to align with the unified API.


140-140: LGTM!

Correctly passes self.num_steps to dp_lr.build() method.


228-236: LGTM!

The _get_dp_lr method correctly returns a LearningRateSchedule with a dict payload containing the required configuration parameters.

source/tests/pt/test_lr.py (7)

13-15: LGTM!

Import updated to use the new LearningRateSchedule from TensorFlow utils.


22-24: LGTM!

Good addition of the comment explaining the constraint that decay_steps must not exceed num_steps. The variable naming change from stop_steps to num_steps aligns with the unified API.


35-42: LGTM!

Correctly migrated to dict-based payload for LearningRateSchedule construction.


48-53: LGTM!

Updated to use keyword arguments for LearningRateExp instantiation, improving clarity.


77-84: Good use of local variable to avoid side effects.

Using decay_step_for_rate as a local variable instead of modifying self.decay_step prevents unintended side effects in subsequent test iterations.


85-91: LGTM!

Properly uses keyword arguments and sets stop_lr=1e-10 as a minimal value for the decay rate override test.


121-138: LGTM!

The TestLearningRateCosine class properly tests the cosine annealing schedule with updated variable naming (num_steps) and keyword arguments.

source/tests/universal/dpmodel/utils/test_learning_rate.py (6)

1-12: LGTM!

Good file setup with appropriate imports for testing the unified learning rate schedules.


15-51: LGTM!

Comprehensive tests for basic exponential decay, including stop_lr_rate parameter and explicit decay_rate override functionality.


53-75: LGTM!

Good coverage of cosine annealing basics, including the midpoint value check at (start_lr + stop_lr) / 2.


77-145: LGTM!

Thorough warmup testing covering:

  • warmup_steps with both exponential and cosine decay
  • warmup_ratio parameter
  • warmup_start_factor parameter
  • Zero warmup edge case

147-179: LGTM!

Good tests for array input support, which is essential for JIT compatibility across backends.


181-202: LGTM!

Important tests verifying that learning rates clamp to stop_lr when steps exceed num_steps.


- Add comprehensive documentation for warmup parameters (warmup_steps, warmup_ratio, warmup_start_factor)
- Document stop_lr_rate as an alternative to stop_lr for ratio-based specification
- Add complete documentation for cosine annealing scheduler
- Update theory section with separate warmup and decay phase formulas
- Provide 7 configuration examples covering both exp and cosine types with warmup variants
@github-actions github-actions bot added the Docs label Jan 15, 2026
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@doc/train/training-advanced.md`:
- Around lines 86-93: The exponential decay formula lr(t) = start_lr * decay_rate ^ ( (t - warmup_steps) / decay_steps ) does not state that the division is integer (floor) division, which causes confusion about the discrete updates. Update the documentation near lr(t) and the discussion of start_lr/stop_lr to either replace "/ decay_steps" with explicit floor notation (e.g., ⌊(t - warmup_steps) / decay_steps⌋) or add a short note stating that the division is floor (integer) division, so readers know the learning rate updates in discrete steps rather than continuously. Reference the symbols lr(t), start_lr, decay_rate, warmup_steps, and decay_steps when making the change.
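In other words, with explicit floor notation the documented formula would read:

$lr(t) = \mathrm{start\_lr} \cdot \mathrm{decay\_rate}^{\left\lfloor (t - \mathrm{warmup\_steps}) / \mathrm{decay\_steps} \right\rfloor}$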
🧹 Nitpick comments (3)
doc/train/training-advanced.md (3)

66-70: Replace hard tabs with spaces in JSON examples.

The JSON configuration examples contain hard tabs, which can cause rendering inconsistencies. Markdown files should use spaces for indentation.

♻️ Formatting recommendation

Replace all hard tabs with spaces (typically 4 spaces per indentation level) in all JSON code blocks throughout the examples section.

Also applies to: 120-125, 131-136, 144-151, 159-165, 175-179, 187-191, 197-201, 209-215


90-93: Add language identifiers to code blocks.

The code blocks on lines 90-93 and 100-103 lack language identifiers. Adding identifiers improves readability and enables proper syntax highlighting.

♻️ Suggested improvement
-```
+```text
 lr(t) = start_lr * decay_rate ^ ( (t - warmup_steps) / decay_steps )
 ```

Apply the same change to the cosine formula block.

Also applies to: 100-103


116-116: Use proper heading levels instead of bold text for examples.

Example titles (lines 116, 127, 140, 155, 183, 193, 205) use bold text instead of proper Markdown headings. Using heading levels (e.g., ####) improves document structure and navigation.

♻️ Suggested improvement
-**Example 1: Basic exponential decay without warmup**
+#### Example 1: Basic exponential decay without warmup

Apply the same change to all seven example titles.

Also applies to: 127-127, 140-140, 155-155, 183-183, 193-193, 205-205

📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 84f3953 and 189ed4c.

📒 Files selected for processing (1)
  • doc/train/training-advanced.md
🧰 Additional context used
🪛 LanguageTool
doc/train/training-advanced.md

[grammar] ~58-~58: Ensure spelling is correct
Context: ..."cosine"`). Both types support optional warmup and can use either absolute stopping le...

(QB_NEW_EN_ORTHOGRAPHY_ERROR_IDS_1)

🪛 markdownlint-cli2 (0.18.1)
doc/train/training-advanced.md

Hard tabs (MD010, no-hard-tabs) at line:column 66:1,9; 67:1,13; 68:1,12; 69:1,16; 70:1,13; 120:1,9; 121:1,13; 122:1,12; 123:1,16; 131:1,9; 132:1,13; 133:1,17; 134:1,16; 144:1,9; 145:1,13; 146:1,12; 147:1,16; 148:1,17; 149:1; 159:1,9; 160:1,13; 161:1,17; 162:1,16; 163:1,17; 175:1,9; 176:1,13; 177:1,12; 187:1,9; 188:1,13; 189:1,12; 197:1,9; 198:1,13; 199:1,17; 209:1,9; 210:1,13; 211:1,12; 212:1,17; 213:1

Fenced code blocks should have a language specified (MD040, fenced-code-language) at lines 90, 100

Emphasis used instead of a heading (MD036, no-emphasis-as-heading) at lines 116, 127, 140, 155, 183, 193, 205

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (40)
  • GitHub Check: Test C++ (true, false, false, true)
  • GitHub Check: Test C++ (true, true, true, false)
  • GitHub Check: Test C++ (false, false, false, true)
  • GitHub Check: Test C++ (false, true, true, false)
  • GitHub Check: Test Python (9, 3.10)
  • GitHub Check: Test Python (10, 3.13)
  • GitHub Check: Test Python (12, 3.13)
  • GitHub Check: Test Python (8, 3.10)
  • GitHub Check: Test Python (12, 3.10)
  • GitHub Check: Test Python (11, 3.10)
  • GitHub Check: Test Python (7, 3.13)
  • GitHub Check: Test Python (7, 3.10)
  • GitHub Check: Test Python (11, 3.13)
  • GitHub Check: Test Python (9, 3.13)
  • GitHub Check: Test Python (10, 3.10)
  • GitHub Check: Test Python (8, 3.13)
  • GitHub Check: Test Python (4, 3.13)
  • GitHub Check: Test Python (4, 3.10)
  • GitHub Check: Test Python (2, 3.13)
  • GitHub Check: Test Python (2, 3.10)
  • GitHub Check: Test Python (6, 3.10)
  • GitHub Check: Test Python (5, 3.10)
  • GitHub Check: Test Python (1, 3.10)
  • GitHub Check: Test Python (5, 3.13)
  • GitHub Check: Test Python (6, 3.13)
  • GitHub Check: Test Python (3, 3.10)
  • GitHub Check: Test Python (1, 3.13)
  • GitHub Check: Test Python (3, 3.13)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Analyze (python)
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
🔇 Additional comments (1)
doc/train/training-advanced.md (1)

78-81: All documented parameter constraints are properly enforced in the implementation.

The verification confirms that:

  • stop_lr and stop_lr_rate are mutually exclusive, with validation at lines 74-78 of deepmd/dpmodel/utils/learning_rate.py
  • Either stop_lr or stop_lr_rate must be provided, enforced at lines 79-83
  • warmup_steps and warmup_ratio are mutually exclusive, validated at lines 95-99
  • warmup_start_factor defaults to 0.0 as documented, with the default defined at line 39

All constraints mentioned in the documentation match the actual implementation behavior. No changes needed.


…ay mode

- Rename parameter `stop_lr_rate` to `stop_lr_ratio` for clearer semantics
- Add `smooth` parameter to LearningRateExp for continuous vs stepped decay
- Add comprehensive mathematical formulas to learning rate docstrings
- Fix dtype handling in pt/pd to_numpy_array (respect GLOBAL_NP_FLOAT_PRECISION)
- Remove redundant validation and caching in TF LearningRateSchedule wrapper
- Update documentation and examples with new parameter names and smooth mode

Breaking change: `stop_lr_rate` is now `stop_lr_ratio` in learning rate config.
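
For context, a hedged example of a learning_rate section using the renamed and new fields from this commit, written as a Python dict for brevity; the accepted schema is defined in deepmd/utils/argcheck.py and may differ in detail.

```python
learning_rate = {
    "type": "exp",
    "start_lr": 1e-3,
    "stop_lr_ratio": 0.01,  # final LR = start_lr * stop_lr_ratio, instead of an absolute stop_lr
    "decay_steps": 5000,
    "smooth": True,         # continuous decay rather than stepped decay
}
```
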
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 7

🤖 Fix all issues with AI agents
In `@deepmd/pt/utils/utils.py`:
- Around line 237-246: The precision lookup uses reverse_precision_dict (built
from PT_PRECISION_DICT) against xx.dtype which fails for numpy arrays; move the
ndarray branch before the reverse lookup or add a separate numpy-dtype lookup
using NP_PRECISION_DICT. Concretely, check isinstance(xx, np.ndarray) first and
return xx.astype(NP_PRECISION_DICT.get(xx.dtype, desired_default)) or, if you
want to preserve original numpy precision, simply return xx as-is; otherwise
proceed with reverse_precision_dict -> prec -> NP_PRECISION_DICT lookup and the
torch.Tensor assertion.

In `@deepmd/tf/utils/learning_rate.py`:
- Around line 45-63: Ruff flags TRY003 for the RuntimeError message in the
base_lr property; either silence it by appending "# noqa: TRY003" to the raise
line or replace the RuntimeError with a small custom exception (e.g., class
LearningRateNotBuiltError(RuntimeError): pass) and raise
LearningRateNotBuiltError("Learning rate schedule is not built yet."). Update
the Raises section in the docstring accordingly and apply the same change to the
other occurrence covering lines 98-120 (same pattern referencing self._base_lr
and the raise).
- Around line 64-97: The code in build() hardcodes float64 for the numpy dtype
and tf.numpy_function Tout, which conflicts with the project's
GLOBAL_TF_FLOAT_PRECISION; update the _lr_value numpy conversion and the
tf.numpy_function Tout to use GLOBAL_TF_FLOAT_PRECISION (map to the matching
numpy dtype for np.asarray and to the TensorFlow dtype for Tout) so
BaseLR.value(step) is converted using that precision; ensure
GLOBAL_TF_FLOAT_PRECISION is imported/available in this module and that the
returned tf.Tensor uses that dtype consistently (you may need to cast the
tf.numpy_function result to GLOBAL_TF_FLOAT_PRECISION if required).

In `@deepmd/utils/argcheck.py`:
- Around line 2549-2581: The validation in _check_decay_steps_args incorrectly
rejects typical decay rates (e.g., 0.95); change the decay_rate check to allow
values in the (0, 1] range by replacing the condition with something like: if
decay_rate is not None and (decay_rate <= 0 or decay_rate > 1): raise
ValueError(f"decay_rate ({decay_rate}) must be > 0 and <= 1."), keeping the
lr_type and decay_steps logic unchanged.
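
A minimal sketch of the relaxed check proposed in the argcheck item above, allowing decay_rate anywhere in (0, 1]; the surrounding lr_type/decay_steps logic is unchanged and omitted here.

```python
def check_decay_rate(decay_rate) -> None:
    if decay_rate is not None and (decay_rate <= 0 or decay_rate > 1):
        raise ValueError(f"decay_rate ({decay_rate}) must be > 0 and <= 1.")
```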

In `@doc/train/training-advanced.md`:
- Around line 64-238: Replace tab characters with spaces in all JSON code blocks
(e.g., the "learning_rate" examples where keys like "learning_rate", "type",
"start_lr", "decay_steps" are indented with tabs), add a fenced-code language
label to every triple-backtick block (use ```json for JSON examples or ```text
for plain text), and convert any bold-as-heading lines (e.g., "**Basic
parameters**", "**Additional parameters for `exp` type only:**", "**Learning
rate formula for `exp` type:**", etc.) into proper Markdown headings (e.g.,
prepend appropriate `#/`##/###) so lint rules MD010, MD040 and MD036 are
satisfied.
- Around line 105-114: The doc reuses the name decay_steps for the cosine
denominator which conflicts with the exponential schedule's decay_steps; update
the cosine section to compute and use a distinct name (e.g., decay_phase_length
or decay_length) defined as decay_phase_length = numb_steps - warmup_steps and
use it in the cosine formula lr(t) = stop_lr + (start_lr - stop_lr) / 2 * (1 +
cos(pi * (t - warmup_steps) / decay_phase_length)); ensure references to
decay_steps in the cosine paragraph are renamed to this new symbol to avoid
confusion with the exp schedule.
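
A sketch of the cosine schedule written with the distinct symbol suggested above (decay_phase_length = num_steps - warmup_steps); illustrative only, not the deepmd implementation.

```python
import numpy as np

def cosine_lr(t, start_lr, stop_lr, num_steps, warmup_steps=0):
    decay_phase_length = num_steps - warmup_steps
    x = np.clip((t - warmup_steps) / decay_phase_length, 0.0, 1.0)  # clamp past the decay end
    return stop_lr + 0.5 * (start_lr - stop_lr) * (1.0 + np.cos(np.pi * x))
```
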
♻️ Duplicate comments (1)
deepmd/dpmodel/utils/learning_rate.py (1)

31-123: Add explicit start_lr > 0 validation to avoid ZeroDivision/log-domain failures.
Even if “start_lr cannot be 0” by convention, current code will crash (e.g., self.stop_lr / self.start_lr in cosine, and log(clamped_stop_lr / self.start_lr) in exp) with a non-obvious stack trace if config is wrong.

Also applies to: 441-442
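
A minimal sketch of the explicit validation suggested above, assuming it would live early in BaseLR.__init__; placement and wording are illustrative.

```python
def validate_start_lr(start_lr: float) -> None:
    # guard against ZeroDivisionError / log-domain errors downstream
    if start_lr <= 0:
        raise ValueError(f"start_lr ({start_lr}) must be a positive number.")
```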

🧹 Nitpick comments (5)
deepmd/pt/utils/utils.py (1)

222-227: Consider adding overloads for new input types.

The function now accepts np.ndarray and float inputs, but the overloads only cover torch.Tensor and None. Type checkers may flag calls with these new types as invalid.

♻️ Suggested overloads
 @overload
 def to_numpy_array(xx: torch.Tensor) -> np.ndarray: ...


 @overload
 def to_numpy_array(xx: None) -> None: ...


+@overload
+def to_numpy_array(xx: np.ndarray) -> np.ndarray: ...
+
+
+@overload
+def to_numpy_array(xx: float) -> np.ndarray: ...
+
+
 def to_numpy_array(
     xx: torch.Tensor | np.ndarray | float | None,
 ) -> np.ndarray | None:
deepmd/utils/argcheck.py (1)

2483-2547: Stop-LR and warmup mutual-exclusion checks are clear and match the new config semantics.

deepmd/dpmodel/utils/learning_rate.py (1)

333-361: xp.clip(..., None) may not be Array-API-strict; prefer one-sided clamp via xp.maximum.
If you truly need array_api_strict compatibility, clip’s signature can be a gotcha across namespaces. A one-sided lower clamp avoids relying on None support.

Proposed change
@@
-        # Clip to min_lr for numerical stability in JIT
-        step_lr = xp.clip(step_lr, self.min_lr, None)
+        # Clamp to min_lr for numerical stability in JIT
+        step_lr = xp.maximum(step_lr, self.min_lr)
         return step_lr
source/tests/universal/dpmodel/utils/test_learning_rate.py (2)

90-90: Use atol instead of rtol when comparing to zero.

When desired=0.0, the tolerance formula |actual - desired| <= atol + rtol * |desired| reduces to |actual| <= 0 with default atol=0. This requires exact equality. Use atol for near-zero comparisons to avoid fragile tests.

Same issue applies to lines 104, 162, and 176.

♻️ Suggested fix
-        np.testing.assert_allclose(lr.value(0), 0.0, rtol=1e-10)
+        np.testing.assert_allclose(lr.value(0), 0.0, atol=1e-10)

Apply similar changes to lines 104, 162, and 176.


204-240: Consider adding tests for mutual-exclusion validation.

Per the PR objectives, the implementation adds "mutual-exclusion validation for warmup_steps vs warmup_ratio" and "mutual-exclusion validation for stop_lr vs stop_ratio". These validation behaviors should be tested to ensure they raise appropriate errors when both parameters are specified.

♻️ Suggested additional tests
def test_stop_lr_and_stop_lr_ratio_mutual_exclusion(self) -> None:
    """Test that specifying both stop_lr and stop_lr_ratio raises ValueError."""
    with self.assertRaises(ValueError):
        LearningRateExp(
            start_lr=1e-3,
            stop_lr=1e-5,
            stop_lr_ratio=0.01,
            num_steps=10000,
            decay_steps=5000,
        )

def test_warmup_steps_and_warmup_ratio_mutual_exclusion(self) -> None:
    """Test that specifying both warmup_steps and warmup_ratio raises ValueError."""
    with self.assertRaises(ValueError):
        LearningRateExp(
            start_lr=1e-3,
            stop_lr=1e-5,
            num_steps=10000,
            decay_steps=5000,
            warmup_steps=1000,
            warmup_ratio=0.1,
        )
📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 189ed4c and 2e11654.

📒 Files selected for processing (8)
  • deepmd/dpmodel/utils/learning_rate.py
  • deepmd/pd/utils/utils.py
  • deepmd/pt/utils/utils.py
  • deepmd/tf/train/trainer.py
  • deepmd/tf/utils/learning_rate.py
  • deepmd/utils/argcheck.py
  • doc/train/training-advanced.md
  • source/tests/universal/dpmodel/utils/test_learning_rate.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • deepmd/pd/utils/utils.py
🧰 Additional context used
🧬 Code graph analysis (4)
deepmd/tf/utils/learning_rate.py (1)
deepmd/dpmodel/utils/learning_rate.py (2)
  • BaseLR (25-189)
  • value (146-189)
deepmd/tf/train/trainer.py (1)
deepmd/tf/utils/learning_rate.py (3)
  • LearningRateSchedule (20-119)
  • start_lr (34-43)
  • build (64-96)
deepmd/pt/utils/utils.py (1)
deepmd/pt/model/network/network.py (1)
  • Tensor (34-35)
deepmd/dpmodel/utils/learning_rate.py (3)
deepmd/tf/utils/learning_rate.py (2)
  • start_lr (34-43)
  • value (98-119)
deepmd/pd/train/training.py (1)
  • step (722-963)
deepmd/pt/train/training.py (1)
  • step (773-1142)
🪛 LanguageTool
doc/train/training-advanced.md

[grammar] ~58-~58: Ensure spelling is correct
Context: ..."cosine"`). Both types support optional warmup and can use either absolute stopping le...

(QB_NEW_EN_ORTHOGRAPHY_ERROR_IDS_1)

🪛 markdownlint-cli2 (0.18.1)
doc/train/training-advanced.md

Hard tabs (MD010, no-hard-tabs) at line:column 66:1,9; 67:1,13; 68:1,12; 69:1,16; 70:1,13; 129:1,9; 130:1,13; 131:1,12; 132:1,16; 140:1,9; 141:1,13; 142:1,18; 143:1,16; 153:1,9; 154:1,13; 155:1,12; 156:1,16; 157:1,17; 158:1; 168:1,9; 169:1,13; 170:1,18; 171:1,16; 172:1,17; 184:1,9; 185:1,13; 186:1,12; 196:1,9; 197:1,13; 198:1,12; 206:1,9; 207:1,13; 208:1,18; 218:1,9; 219:1,13; 220:1,12; 221:1,17; 222:1; 232:1,9; 233:1,13; 234:1,12; 235:1,16; 236:1,11

Fenced code blocks should have a language specified (MD040, fenced-code-language) at lines 93, 99, 109

Emphasis used instead of a heading (MD036, no-emphasis-as-heading) at lines 125, 136, 149, 164, 192, 202, 214, 228

🪛 Ruff (0.14.11)
deepmd/tf/utils/learning_rate.py

61-61: Avoid specifying long messages outside the exception class

(TRY003)


118-118: Avoid specifying long messages outside the exception class

(TRY003)

deepmd/pt/utils/utils.py

243-243: Avoid specifying long messages outside the exception class

(TRY003)

deepmd/dpmodel/utils/learning_rate.py

40-40: Unused method argument: kwargs

(ARG002)


75-78: Avoid specifying long messages outside the exception class

(TRY003)


80-83: Avoid specifying long messages outside the exception class

(TRY003)


96-99: Avoid specifying long messages outside the exception class

(TRY003)


109-109: Avoid specifying long messages outside the exception class

(TRY003)


111-111: Avoid specifying long messages outside the exception class

(TRY003)


113-113: Avoid specifying long messages outside the exception class

(TRY003)


309-309: Avoid specifying long messages outside the exception class

(TRY003)


312-315: Avoid specifying long messages outside the exception class

(TRY003)

deepmd/utils/argcheck.py

2506-2509: Avoid specifying long messages outside the exception class

(TRY003)


2511-2514: Avoid specifying long messages outside the exception class

(TRY003)


2542-2545: Avoid specifying long messages outside the exception class

(TRY003)


2575-2575: Avoid specifying long messages outside the exception class

(TRY003)


2579-2579: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (40)
  • GitHub Check: Test Python (11, 3.13)
  • GitHub Check: Test Python (10, 3.10)
  • GitHub Check: Test Python (11, 3.10)
  • GitHub Check: Test Python (9, 3.13)
  • GitHub Check: Test Python (7, 3.10)
  • GitHub Check: Test Python (8, 3.10)
  • GitHub Check: Test Python (4, 3.10)
  • GitHub Check: Test Python (5, 3.13)
  • GitHub Check: Test Python (4, 3.13)
  • GitHub Check: Test Python (7, 3.13)
  • GitHub Check: Test Python (6, 3.10)
  • GitHub Check: Test Python (8, 3.13)
  • GitHub Check: Test Python (10, 3.13)
  • GitHub Check: Test Python (6, 3.13)
  • GitHub Check: Test Python (5, 3.10)
  • GitHub Check: Test Python (9, 3.10)
  • GitHub Check: Test Python (3, 3.10)
  • GitHub Check: Test Python (12, 3.10)
  • GitHub Check: Test Python (12, 3.13)
  • GitHub Check: Test Python (2, 3.10)
  • GitHub Check: Test Python (1, 3.10)
  • GitHub Check: Test Python (1, 3.13)
  • GitHub Check: Test Python (3, 3.13)
  • GitHub Check: Test Python (2, 3.13)
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Analyze (python)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Test C++ (false, true, true, false)
  • GitHub Check: Test C++ (false, false, false, true)
  • GitHub Check: Test C++ (true, true, true, false)
  • GitHub Check: Test C++ (true, false, false, true)
🔇 Additional comments (13)
deepmd/pt/utils/utils.py (2)

19-19: LGTM!

Import is correctly added to support the new scalar conversion logic.


235-236: LGTM!

Scalar handling correctly wraps numeric inputs into numpy arrays with the configured global precision.

deepmd/tf/utils/learning_rate.py (1)

20-44: Clean wrapper API (start_lr() + params copy) and clear intent.

deepmd/tf/train/trainer.py (1)

106-118: get_lr_and_coef() update is straightforward and keeps scale_by_worker out of LR params.

doc/train/training-advanced.md (1)

9-52: Theory section reads consistent with the new warmup + (exp/cosine) design, including explicit floor() for stepped exp.

deepmd/utils/argcheck.py (1)

2583-2756: Good consolidation: shared args + centralized extra_check keeps LR variants consistent.

deepmd/dpmodel/utils/learning_rate.py (2)

146-189: Warmup/decay split in BaseLR.value() is readable and matches the documented piecewise schedule.


443-477: Cosine decay clamps correctly past the decay phase end (prevents cosine “bounce back”).

source/tests/universal/dpmodel/utils/test_learning_rate.py (5)

1-12: LGTM!

Imports are appropriate for the test module.


15-50: LGTM!

Good test coverage for basic exponential decay functionality, including stop_lr_ratio and decay_rate override.


53-74: LGTM!

Good test coverage for cosine annealing, including the mathematically correct midpoint check.


147-178: LGTM!

Good test coverage for array/batch inputs, which is important for JIT compatibility. Shape verification is appropriate.


181-201: LGTM!

Good edge-case coverage for behavior beyond num_steps, verifying the clamping to stop_lr.


When decay_steps exceeds the decay phase (num_steps - warmup_steps) and
decay_rate is not explicitly provided, automatically adjust decay_steps to
a sensible default (capped at 100, or decay_total//100 + 1) instead of
raising ValueError.

This makes the learning rate scheduler more user-friendly by gracefully
handling misconfigured decay_steps values.

Changes:
- LearningRateExp: auto-adjust decay_steps when >= decay_total
- Update argcheck and training-advanced.md documentation
- Update pd/pt/tf test_lr.py to use auto-adjusted decay_steps
- Remove obsolete validation tests in test_learning_rate.py
- Fix tf test dtype: float32 -> float64
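
A hedged sketch of the auto-adjustment described above; the actual branching in LearningRateExp (and its interaction with an explicitly provided decay_rate) may differ.

```python
def auto_adjust_decay_steps(decay_steps: int, decay_total: int) -> int:
    if decay_steps >= decay_total:
        # fall back to a sensible default instead of raising ValueError
        return min(100, decay_total // 100 + 1)
    return decay_steps
```
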
Copy link
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 36 out of 36 changed files in this pull request and generated 7 comments.



Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@deepmd/tf/utils/learning_rate.py`:
- Around line 92-107: The closure _lr_value currently references self._base_lr
at evaluation time, which can break if build() is called multiple times; capture
the BaseLR instance into a local variable (e.g., base_lr = self._base_lr) inside
LearningRateSchedule.build() before defining _lr_value, then have _lr_value use
that local base_lr instead of self._base_lr so each returned tensor is bound to
the BaseLR instance present at closure creation.
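
A sketch of the suggested fix: bind the BaseLR instance to a local when the callback is created, so later build() calls cannot re-bind an already returned tensor. Names and the float64 dtype are illustrative only.

```python
import numpy as np

def make_lr_callback(base_lr):
    def _lr_value(step: np.ndarray) -> np.ndarray:
        # uses the captured instance, not self._base_lr at evaluation time
        return np.asarray(base_lr.value(step), dtype=np.float64)
    return _lr_value
```
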
♻️ Duplicate comments (2)
deepmd/tf/utils/learning_rate.py (2)

49-66: TRY003 lint warning on RuntimeError message.

This was already flagged in a previous review. Consider adding # noqa: TRY003 or defining a custom exception class if CI enforces this rule.


109-130: TRY003 lint warning on RuntimeError message (same as base_lr property).

Already addressed in past review. Apply the same fix (noqa comment or custom exception) here as well.

🧹 Nitpick comments (2)
deepmd/pt/utils/utils.py (1)

230-242: Add int to the type annotations for consistency.

The implementation handles int at line 239 (isinstance(xx, (float, int))), but the overloads and main signature don't include it. This creates a type annotation mismatch.

♻️ Proposed fix to include int in type annotations
 @overload
-def to_numpy_array(xx: float) -> np.ndarray: ...
+def to_numpy_array(xx: float | int) -> np.ndarray: ...


 def to_numpy_array(
-    xx: torch.Tensor | np.ndarray | float | None,
+    xx: torch.Tensor | np.ndarray | float | int | None,
 ) -> np.ndarray | None:
deepmd/tf/utils/learning_rate.py (1)

38-47: Consider making start_lr a property for API consistency.

base_lr is defined as a property, but start_lr is a method. This inconsistency may confuse users of the API.

♻️ Suggested refactor
-    def start_lr(self) -> float:
+    @property
+    def start_lr(self) -> float:
         """
         Get the starting learning rate.

         Returns
         -------
         float
             The starting learning rate.
         """
         return float(self._params["start_lr"])

Please verify that callers use start_lr as a method (with parentheses) or update them if changing to a property:

#!/bin/bash
# Search for usages of start_lr to verify call patterns
rg -n "\.start_lr\(" --type py
📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between aca34df and e5f2ef5.

📒 Files selected for processing (4)
  • deepmd/pt/utils/utils.py
  • deepmd/tf/train/trainer.py
  • deepmd/tf/utils/learning_rate.py
  • source/tests/universal/dpmodel/utils/test_learning_rate.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • source/tests/universal/dpmodel/utils/test_learning_rate.py
  • deepmd/tf/train/trainer.py
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2024-10-15T22:22:24.889Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4219
File: deepmd/utils/learning_rate.py:48-53
Timestamp: 2024-10-15T22:22:24.889Z
Learning: Methods in `deepmd/utils/learning_rate.py` that return NumPy scalar types should have return type annotations using the corresponding NumPy types, such as `np.float64`.

Applied to files:

  • deepmd/tf/utils/learning_rate.py
  • deepmd/pt/utils/utils.py
📚 Learning: 2024-10-26T02:09:01.365Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4258
File: deepmd/jax/utils/neighbor_stat.py:98-101
Timestamp: 2024-10-26T02:09:01.365Z
Learning: The function `to_jax_array` in `deepmd/jax/common.py` can handle `None` values, so it's safe to pass `None` to it without additional checks.

Applied to files:

  • deepmd/pt/utils/utils.py
🧬 Code graph analysis (2)
deepmd/tf/utils/learning_rate.py (1)
deepmd/dpmodel/utils/learning_rate.py (2)
  • BaseLR (25-189)
  • value (146-189)
deepmd/pt/utils/utils.py (3)
deepmd/pd/utils/utils.py (3)
  • to_numpy_array (231-231)
  • to_numpy_array (235-235)
  • to_numpy_array (238-256)
deepmd/dpmodel/common.py (1)
  • to_numpy_array (106-128)
deepmd/pt/model/network/network.py (1)
  • Tensor (34-35)
🪛 Ruff (0.14.11)
deepmd/tf/utils/learning_rate.py

65-65: Avoid specifying long messages outside the exception class

(TRY003)


129-129: Avoid specifying long messages outside the exception class

(TRY003)

deepmd/pt/utils/utils.py

249-249: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (40)
  • GitHub Check: Test Python (9, 3.10)
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Test Python (11, 3.13)
  • GitHub Check: Test Python (7, 3.10)
  • GitHub Check: Test Python (11, 3.10)
  • GitHub Check: Test Python (10, 3.10)
  • GitHub Check: Test Python (10, 3.13)
  • GitHub Check: Test Python (9, 3.13)
  • GitHub Check: Test Python (12, 3.13)
  • GitHub Check: Test Python (1, 3.10)
  • GitHub Check: Test Python (6, 3.10)
  • GitHub Check: Test Python (3, 3.10)
  • GitHub Check: Test Python (12, 3.10)
  • GitHub Check: Test Python (2, 3.10)
  • GitHub Check: Test Python (6, 3.13)
  • GitHub Check: Test Python (7, 3.13)
  • GitHub Check: Test Python (3, 3.13)
  • GitHub Check: Test Python (1, 3.13)
  • GitHub Check: Test Python (5, 3.13)
  • GitHub Check: Test Python (4, 3.13)
  • GitHub Check: Test Python (2, 3.13)
  • GitHub Check: Test Python (8, 3.13)
  • GitHub Check: Test Python (4, 3.10)
  • GitHub Check: Test Python (8, 3.10)
  • GitHub Check: Test Python (5, 3.10)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Analyze (python)
  • GitHub Check: Test C++ (false, true, true, false)
  • GitHub Check: Test C++ (true, true, true, false)
  • GitHub Check: Test C++ (false, false, false, true)
  • GitHub Check: Test C++ (true, false, false, true)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
🔇 Additional comments (3)
deepmd/pt/utils/utils.py (1)

241-254: LGTM! The numpy array handling is correctly positioned.

The np.ndarray check now comes before the precision lookup (lines 241-242), which correctly handles numpy arrays by converting to GLOBAL_NP_FLOAT_PRECISION without attempting to use reverse_precision_dict (which contains PyTorch dtypes). The torch.Tensor path (lines 243-254) correctly maps through PT_PRECISION_DICTNP_PRECISION_DICT and handles bfloat16 conversion.
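
A hedged sketch of the ordering described above: handle None, plain numpy arrays, and Python scalars before any torch-dtype reverse lookup. GLOBAL_NP_FLOAT_PRECISION is assumed to be the configured numpy precision; the function name marks this as a sketch, not the real helper.

```python
import numpy as np

GLOBAL_NP_FLOAT_PRECISION = np.float64  # assumption for this sketch

def to_numpy_array_sketch(xx):
    if xx is None:
        return None
    if isinstance(xx, np.ndarray):
        return xx.astype(GLOBAL_NP_FLOAT_PRECISION)
    if isinstance(xx, (float, int)):
        return np.asarray(xx, dtype=GLOBAL_NP_FLOAT_PRECISION)
    # torch.Tensor inputs would be handled here via PT_PRECISION_DICT -> NP_PRECISION_DICT
    raise TypeError(f"unsupported input type: {type(xx)}")
```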

deepmd/tf/utils/learning_rate.py (2)

34-36: LGTM!

The defensive shallow copy via dict(params) prevents mutation of the caller's dictionary, and lazy initialization of _base_lr is appropriate for the two-phase initialization pattern.


133-135: LGTM!

The __all__ export correctly exposes only the public LearningRateSchedule class.


@codecov
Copy link

codecov bot commented Jan 15, 2026

Codecov Report

❌ Patch coverage is 87.42515% with 21 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.93%. Comparing base (2a9667e) to head (ad40b00).
⚠️ Report is 2 commits behind head on master.

Files with missing lines Patch % Lines
deepmd/dpmodel/utils/learning_rate.py 84.61% 10 Missing ⚠️
deepmd/utils/argcheck.py 87.23% 6 Missing ⚠️
deepmd/pt/utils/utils.py 60.00% 2 Missing ⚠️
deepmd/pd/utils/utils.py 50.00% 1 Missing ⚠️
deepmd/pt/train/training.py 66.66% 1 Missing ⚠️
deepmd/tf/utils/learning_rate.py 96.66% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5154      +/-   ##
==========================================
- Coverage   81.94%   81.93%   -0.01%     
==========================================
  Files         713      714       +1     
  Lines       73010    73380     +370     
  Branches     3616     3617       +1     
==========================================
+ Hits        59825    60125     +300     
- Misses      12022    12093      +71     
+ Partials     1163     1162       -1     

☔ View full report in Codecov by Sentry.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@deepmd/dpmodel/utils/learning_rate.py`:
- Around line 322-329: The decay_rate calculation can divide by zero when
num_steps (hence decay_total) is 0; update the block that sets self.decay_rate
(the else branch using np.exp(np.log(clamped_stop_lr / self.start_lr) /
(decay_total / self.decay_steps)).item()) to guard against decay_total <= 0 or
self.decay_steps == 0: when either is zero, set self.decay_rate to 1.0 (no
decay) instead of performing the log/division, otherwise keep the existing
calculation using clamped_stop_lr, self.start_lr, decay_total and
self.decay_steps.
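
A sketch of the guard described above; variable names mirror the ones quoted in the comment (clamped_stop_lr, decay_total, decay_steps), and the helper is illustrative rather than the actual method body.

```python
import numpy as np

def compute_decay_rate(start_lr, clamped_stop_lr, decay_total, decay_steps):
    if decay_total <= 0 or decay_steps == 0:
        return 1.0  # no decay phase: keep the learning rate constant
    return float(np.exp(np.log(clamped_stop_lr / start_lr) / (decay_total / decay_steps)))
```
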
🧹 Nitpick comments (1)
deepmd/dpmodel/utils/learning_rate.py (1)

91-99: Clarify: warmup_steps=0 is treated as "not provided".

The check has_warmup_steps = warmup_steps != 0 means explicitly setting warmup_steps=0 while also providing warmup_ratio will silently use warmup_ratio instead of raising a mutual exclusivity error.

This is likely intentional since warmup_steps=0 is the default, but it differs from the stop_lr/stop_lr_ratio check which uses is not None. Consider adding a brief comment clarifying this design choice for maintainability.

📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e5f2ef5 and 9bfcf24.

📒 Files selected for processing (3)
  • deepmd/dpmodel/utils/learning_rate.py
  • deepmd/tf/loss/dos.py
  • deepmd/tf/loss/ener.py
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2024-10-08T15:32:11.479Z
Learnt from: 1azyking
Repo: deepmodeling/deepmd-kit PR: 4169
File: deepmd/pt/loss/ener_hess.py:341-348
Timestamp: 2024-10-08T15:32:11.479Z
Learning: In `deepmd/pt/loss/ener_hess.py`, the `label` uses the key `"atom_ener"` intentionally to maintain consistency with the forked version.

Applied to files:

  • deepmd/tf/loss/dos.py
  • deepmd/tf/loss/ener.py
📚 Learning: 2024-10-05T03:06:02.372Z
Learnt from: 1azyking
Repo: deepmodeling/deepmd-kit PR: 4169
File: deepmd/utils/argcheck.py:1982-2117
Timestamp: 2024-10-05T03:06:02.372Z
Learning: The `loss_ener_hess` and `loss_ener` functions should remain separate to avoid confusion, despite code duplication.

Applied to files:

  • deepmd/tf/loss/ener.py
📚 Learning: 2026-01-10T04:29:25.299Z
Learnt from: OutisLi
Repo: deepmodeling/deepmd-kit PR: 5130
File: deepmd/pt/optimizer/adamuon.py:44-113
Timestamp: 2026-01-10T04:29:25.299Z
Learning: The epsilon parameter (eps=1e-8) in the `zeropower_via_newtonschulz5` function in `deepmd/pt/optimizer/adamuon.py` follows the standard PyTorch AdaMuon reference implementations (which typically use eps=1e-7) and should not be modified for consistency with reference implementations, despite being smaller than bfloat16's machine epsilon.

Applied to files:

  • deepmd/dpmodel/utils/learning_rate.py
🧬 Code graph analysis (2)
deepmd/tf/loss/dos.py (1)
deepmd/tf/env.py (1)
  • global_cvt_2_tf_float (409-422)
deepmd/tf/loss/ener.py (1)
deepmd/tf/env.py (2)
  • global_cvt_2_ener_float (425-438)
  • global_cvt_2_tf_float (409-422)
🪛 Ruff (0.14.11)
deepmd/dpmodel/utils/learning_rate.py

40-40: Unused method argument: kwargs

(ARG002)


75-78: Avoid specifying long messages outside the exception class

(TRY003)


80-83: Avoid specifying long messages outside the exception class

(TRY003)


96-99: Avoid specifying long messages outside the exception class

(TRY003)


109-109: Avoid specifying long messages outside the exception class

(TRY003)


111-111: Avoid specifying long messages outside the exception class

(TRY003)


113-113: Avoid specifying long messages outside the exception class

(TRY003)


115-115: Avoid specifying long messages outside the exception class

(TRY003)


310-310: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (12)
  • GitHub Check: Test Python (4, 3.10)
  • GitHub Check: Test Python (2, 3.10)
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Test C++ (true, false, false, true)
  • GitHub Check: Test C++ (false, true, true, false)
  • GitHub Check: Test C++ (false, false, false, true)
  • GitHub Check: Test C++ (true, true, true, false)
🔇 Additional comments (10)
deepmd/tf/loss/dos.py (1)

68-69: LGTM! Conversion wrappers ensure consistent tensor typing.

The addition of global_cvt_2_tf_float for find_dos and find_atom_dos ensures these indicator values are properly cast to the global TensorFlow precision before being used in the preference coefficient calculations. This aligns with the consistent pattern applied across other loss modules (e.g., find_force, find_virial in ener.py).

deepmd/tf/loss/ener.py (3)

152-159: LGTM! Appropriate precision casting for each indicator field.

The conversion function choices are correct:

  • find_energyglobal_cvt_2_ener_float (matches pref_e's energy precision)
  • find_force, find_virial, find_atom_ener, find_atom_pref, find_drdqglobal_cvt_2_tf_float (matches their respective prefactors' precision)

This ensures type consistency in the arithmetic operations that follow.


592-596: LGTM! Consistent with EnerStdLoss pattern.

The conversion function assignments mirror the pattern in EnerStdLoss: find_energy uses energy precision while force/virial/atom_ener indicators use TF float precision. This maintains consistency across the loss module implementations.


935-936: LGTM! Energy-related indicators use energy precision.

Both find_energy and find_ener_dipole are energy-related indicators and appropriately use global_cvt_2_ener_float. The outer wrappers on pref_e and pref_ed (lines 951 and 960) will handle the final precision casting for the preference coefficients.

deepmd/dpmodel/utils/learning_rate.py (6)

1-29: LGTM!

The imports and class registration pattern using the plugin registry are well-structured and consistent with the codebase architecture.


126-146: LGTM!

The abstract _decay_value method defines a clean contract for subclasses, with clear documentation about the step being relative to the end of warmup.


148-191: LGTM!

The value() method correctly handles:

  • Scalar vs array input conversion
  • Linear warmup phase with proper interpolation
  • Decay phase delegation to subclass _decay_value
  • Boundary continuity at warmup_steps (decay starts at step=0 relative to warmup end, returning start_lr)

The use of xp.where ensures correct branching behavior compatible with JIT compilation.
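
A hedged sketch of the piecewise warmup/decay split described above, using array-API operations so both branches stay traceable under JIT; the real BaseLR.value() may differ in detail.

```python
import array_api_compat

def piecewise_lr(step, start_lr, warmup_steps, warmup_start_factor, decay_fn):
    xp = array_api_compat.array_namespace(step)
    frac = step / max(warmup_steps, 1)
    warmup_lr = start_lr * (warmup_start_factor + (1.0 - warmup_start_factor) * frac)
    decay_lr = decay_fn(xp.maximum(step - warmup_steps, 0))
    # xp.where evaluates both branches but selects element-wise, keeping the graph traceable
    return xp.where(step < warmup_steps, warmup_lr, decay_lr)
```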


334-362: LGTM!

The _decay_value implementation correctly handles both stepped and smooth exponential decay modes, with proper array API usage and device placement for the decay_rate constant.


462-480: LGTM!

Good handling of edge cases:

  • decay_num_steps=0 returns start_lr (no decay phase)
  • Steps beyond decay_num_steps are clipped to min_lr via xp.where

The cosine annealing formula implementation is mathematically correct and matches the docstring.


40-40: Note: Unused kwargs is intentional for plugin extensibility.

The static analysis flags kwargs as unused (ARG002), but this is intentional for the plugin registry pattern—subclasses may receive additional parameters via super().__init__(**kwargs). No change needed.


Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Fix all issues with AI agents
In `@source/tests/tf/test_lr.py`:
- Around line 44-51: The test test_build_returns_tensor hardcodes tf.float64 for
the dtype but the implementation uses GLOBAL_TF_FLOAT_PRECISION; update the
assertion to compare lr_tensor.dtype against GLOBAL_TF_FLOAT_PRECISION (or
tf.as_dtype(GLOBAL_TF_FLOAT_PRECISION) if GLOBAL_TF_FLOAT_PRECISION is a Python
float type) so the test reads the expected dtype from the same source as
LearningRateSchedule.build and will pass for both float32 and float64
configurations.
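
A sketch of the dtype check reading the expected precision from the same place as the implementation; the import path is assumed from the surrounding TF backend code, and the helper is illustrative of the assertion only.

```python
import tensorflow as tf

from deepmd.tf.env import GLOBAL_TF_FLOAT_PRECISION

def assert_lr_dtype(lr_tensor: tf.Tensor) -> None:
    # passes for both float32 and float64 configurations
    assert lr_tensor.dtype == GLOBAL_TF_FLOAT_PRECISION
```
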
♻️ Duplicate comments (1)
deepmd/tf/utils/learning_rate.py (1)

50-67: Consider silencing Ruff TRY003 if CI enforces it.

The RuntimeError message is reasonable and clear. If CI enforces TRY003, add # noqa: TRY003 to line 66.

🧹 Nitpick comments (2)
deepmd/tf/utils/learning_rate.py (1)

96-102: Misleading comment about dtype precision.

The comment says "(float64)" but GLOBAL_TF_FLOAT_PRECISION can be either float32 or float64 depending on configuration. Consider removing the hardcoded precision from the comment to avoid confusion.

Suggested fix
         def _lr_value(step: np.ndarray) -> np.ndarray:
-            # Use GLOBAL_TF_FLOAT_PRECISION (float64) for learning rate,
-            # consistent with energy precision in TF backend
+            # Use GLOBAL_TF_FLOAT_PRECISION for learning rate,
+            # consistent with model precision in TF backend
             return np.asarray(
                 base_lr.value(step),
                 dtype=GLOBAL_TF_FLOAT_PRECISION.as_numpy_dtype,
             )
source/tests/tf/test_lr.py (1)

87-104: Consider consolidating with test_tensor_value_matches_base_lr.

This test is functionally identical to test_tensor_value_matches_base_lr—both verify that value() matches base_lr.value() after build(). Consider consolidating them to reduce duplication, or keep both if the intent is to emphasize different aspects (post-build accessibility vs. value correctness).

📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2ee5021 and 947e8a6.

📒 Files selected for processing (2)
  • deepmd/tf/utils/learning_rate.py
  • source/tests/tf/test_lr.py
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2024-10-15T22:22:24.889Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4219
File: deepmd/utils/learning_rate.py:48-53
Timestamp: 2024-10-15T22:22:24.889Z
Learning: Methods in `deepmd/utils/learning_rate.py` that return NumPy scalar types should have return type annotations using the corresponding NumPy types, such as `np.float64`.

Applied to files:

  • deepmd/tf/utils/learning_rate.py
📚 Learning: 2025-12-12T13:40:14.334Z
Learnt from: CR
Repo: deepmodeling/deepmd-kit PR: 0
File: AGENTS.md:0-0
Timestamp: 2025-12-12T13:40:14.334Z
Learning: Run core tests with `pytest source/tests/tf/test_dp_test.py::TestDPTestEner::test_1frame -v` to validate basic functionality

Applied to files:

  • source/tests/tf/test_lr.py
🧬 Code graph analysis (2)
deepmd/tf/utils/learning_rate.py (1)
deepmd/dpmodel/utils/learning_rate.py (2)
  • BaseLR (25-191)
  • value (148-191)
source/tests/tf/test_lr.py (2)
deepmd/dpmodel/utils/learning_rate.py (2)
  • LearningRateExp (195-365)
  • value (148-191)
deepmd/tf/utils/learning_rate.py (5)
  • LearningRateSchedule (21-131)
  • value (110-131)
  • base_lr (51-67)
  • build (69-108)
  • start_lr (39-48)
🪛 Ruff (0.14.11)
deepmd/tf/utils/learning_rate.py

66-66: Avoid specifying long messages outside the exception class

(TRY003)


130-130: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (38)
  • GitHub Check: Test Python (8, 3.10)
  • GitHub Check: Test Python (12, 3.10)
  • GitHub Check: Test Python (11, 3.13)
  • GitHub Check: Test Python (7, 3.13)
  • GitHub Check: Test Python (4, 3.13)
  • GitHub Check: Test Python (12, 3.13)
  • GitHub Check: Test Python (10, 3.13)
  • GitHub Check: Test Python (6, 3.10)
  • GitHub Check: Test Python (6, 3.13)
  • GitHub Check: Test Python (5, 3.13)
  • GitHub Check: Test Python (4, 3.10)
  • GitHub Check: Test Python (5, 3.10)
  • GitHub Check: Test Python (1, 3.13)
  • GitHub Check: Test Python (1, 3.10)
  • GitHub Check: Test Python (10, 3.10)
  • GitHub Check: Test Python (9, 3.13)
  • GitHub Check: Test Python (2, 3.13)
  • GitHub Check: Test Python (8, 3.13)
  • GitHub Check: Test Python (2, 3.10)
  • GitHub Check: Test Python (7, 3.10)
  • GitHub Check: Test Python (3, 3.13)
  • GitHub Check: Test Python (3, 3.10)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Test C++ (true, false, false, true)
  • GitHub Check: Test C++ (false, false, false, true)
  • GitHub Check: Test C++ (true, true, true, false)
  • GitHub Check: Test C++ (false, true, true, false)
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Analyze (python)
  • GitHub Check: Analyze (c-cpp)
🔇 Additional comments (9)
deepmd/tf/utils/learning_rate.py (4)

2-18: LGTM!

Imports are well-organized. The GLOBAL_TF_FLOAT_PRECISION import and BaseLR delegation align with the refactored architecture.


21-48: LGTM!

The docstring appropriately documents the tf.numpy_function performance consideration. The start_lr() accessor is clean and straightforward.


110-131: LGTM!

The value() method correctly delegates to BaseLR.value() and handles the built-check. The same TRY003 consideration applies to line 130 as noted earlier.


134-136: LGTM!

Clean export of the renamed public class.

source/tests/tf/test_lr.py (5)

1-20: LGTM!

Clear module docstring establishing test scope. Good separation of concerns—TF wrapper logic is tested here while core algorithms are tested in dpmodel tests.


23-38: LGTM!

Good coverage of pre-build error handling. The assertions verify both the exception type and message content.


53-59: LGTM!

Good verification that the default scheduler type resolves to LearningRateExp.


82-85: LGTM!

Clean test for the start_lr() accessor.


107-108: LGTM!

Standard test runner entry point.


self.optimizer,
lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
lambda step: self.lr_exp.value(step + self.start_step)
/ self.lr_exp.start_lr,
Copy link
Collaborator


I would suggest providing a method for accessing the start_lr, rather than directly reading the data of the object.


Labels

breaking change Breaking changes that should notify users. Docs enhancement Examples Python

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants