Conversation

Member

@njzjz njzjz commented Jan 10, 2026

This is useful when the learning rate (LR) is computed inside a JAX JIT-compiled function: there the step is given as a `jnp.ndarray`, so NumPy should not be used.

Summary by CodeRabbit

  • New Features

    • Learning-rate API now accepts array-like steps and computes schedules using backend-agnostic array operations; exponential and cosine schedulers support array inputs and preserve minimum learning-rate behavior.
  • Tests

    • Added cross-backend consistency tests to validate identical learning-rate outputs across NumPy, PyTorch, JAX and array-api backends.
  • Chores

    • Updated test dependency constraint and bumped default test API version used by strict array-api tests.

Copilot AI review requested due to automatic review settings January 10, 2026 16:41
@dosubot dosubot bot added the new feature label Jan 10, 2026
Contributor

Copilot AI left a comment

Pull request overview

This PR adds Array API support to learning rate schedulers, enabling JAX JIT compilation compatibility by allowing the step parameter to be passed as an array (e.g., jnp.ndarray) instead of only a Python integer.

Changes:

  • Modified BaseLR.value() and implementations to accept an xp parameter for Array API namespace
  • Replaced conditional logic with xp.where() and xp.clip() to avoid dynamic graphs in JAX JIT
  • Added comprehensive tests for consistency across NumPy, JAX, PyTorch, and array_api_strict
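The replacement of Python conditionals with array operations listed above can be sketched as follows. This is a minimal illustration under stated assumptions, not the PR's code: the function names `lr_with_branch` and `lr_with_clip` and all parameter values are hypothetical, and only NumPy is used as the namespace `xp`.

```python
import numpy as np


def lr_with_branch(step, start_lr, decay_rate, decay_steps, min_lr):
    """Branching version: the `if` needs a concrete Python bool."""
    step_lr = start_lr * decay_rate ** (step // decay_steps)
    if step_lr < min_lr:
        step_lr = min_lr
    return step_lr


def lr_with_clip(step, start_lr, decay_rate, decay_steps, min_lr, xp=np):
    """Branch-free version: the lower bound is applied as an array operation."""
    step = xp.asarray(step)
    step_lr = start_lr * decay_rate ** (step // decay_steps)
    # max=None means "no upper bound", so only the floor is enforced.
    return xp.clip(step_lr, min_lr, None)


# Hypothetical settings chosen only for illustration.
params = dict(start_lr=1e-3, decay_rate=0.95, decay_steps=5000, min_lr=1e-8)

mid_branch = lr_with_branch(500000, **params)
mid_clip = lr_with_clip(np.asarray(500000), **params)

# Far past the end of training both versions return the 1e-8 floor.
late_branch = lr_with_branch(10_000_000, **params)
late_clip = lr_with_clip(np.asarray(10_000_000), **params)
```

The branch-free form matters under tracing because the comparison `step_lr < min_lr` yields an array rather than a Python bool, so it cannot drive an ordinary `if`.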

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 9 comments.

Files and descriptions:

  • deepmd/dpmodel/utils/learning_rate.py: Updated base class and subclasses to support Array API operations; replaced Python conditionals with array operations for JAX JIT compatibility
  • source/tests/consistent/test_learning_rate.py: Added consistency tests for learning rate schedulers across different array libraries

@coderabbitai
Contributor

coderabbitai bot commented Jan 10, 2026

Walkthrough

Refactor learning-rate implementations to be array-API compatible: value(step) accepts int | Array and returns Array using an array namespace; add cross-backend consistency tests and bump array-api-strict test dependency/version.

Changes

Cohort / File(s) and summary:

  • Learning rate core (deepmd/dpmodel/utils/learning_rate.py): Change BaseLR.value and concrete value signatures to accept `step: int …`
  • Backend consistency tests (source/tests/consistent/test_learning_rate.py): Add TestLearningRateConsistent parameterized tests comparing "exp" and "cosine" LR outputs at step 500000 across NumPy, PyTorch, JAX, and array_api_strict backends (conditional imports, tight atol).
  • Test dependency pin (pyproject.toml): Update test extras: array-api-strict constraint changed to >=2.2;python_version>="3.9" (remove previous !=2.1.1 exclusion).
  • array_api_strict setup (source/tests/array_api_strict/__init__.py): Bump default strict API target version from "2023.12" to "2024.12" used in test initialization.

Sequence Diagram(s)

(Skipped)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested reviewers

  • wanghan-iapcm
  • OutisLi

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage (⚠️ Warning): Docstring coverage is 41.67%, which is insufficient. The required threshold is 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

  • Description Check (✅ Passed): Check skipped - CodeRabbit’s high-level summary is enabled.
  • Title check (✅ Passed): The title 'feat(dpmodel): support Array API learning rate' clearly and concisely describes the main change: adding Array API support to learning rate functionality in the dpmodel module.

📜 Recent review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 98a42be and 615f692.

📒 Files selected for processing (2)
  • deepmd/dpmodel/utils/learning_rate.py
  • source/tests/consistent/test_learning_rate.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • deepmd/dpmodel/utils/learning_rate.py
  • source/tests/consistent/test_learning_rate.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (40)
  • GitHub Check: Test Python (12, 3.13)
  • GitHub Check: Test Python (4, 3.13)
  • GitHub Check: Test Python (5, 3.10)
  • GitHub Check: Test Python (10, 3.13)
  • GitHub Check: Test Python (11, 3.13)
  • GitHub Check: Test Python (9, 3.10)
  • GitHub Check: Test Python (7, 3.13)
  • GitHub Check: Test Python (12, 3.10)
  • GitHub Check: Test Python (6, 3.13)
  • GitHub Check: Test Python (10, 3.10)
  • GitHub Check: Test Python (8, 3.13)
  • GitHub Check: Test Python (8, 3.10)
  • GitHub Check: Test Python (9, 3.13)
  • GitHub Check: Test Python (4, 3.10)
  • GitHub Check: Test Python (11, 3.10)
  • GitHub Check: Test Python (5, 3.13)
  • GitHub Check: Test Python (7, 3.10)
  • GitHub Check: Test Python (6, 3.10)
  • GitHub Check: Test Python (1, 3.13)
  • GitHub Check: Test Python (3, 3.10)
  • GitHub Check: Test Python (2, 3.13)
  • GitHub Check: Test Python (1, 3.10)
  • GitHub Check: Test Python (3, 3.13)
  • GitHub Check: Test Python (2, 3.10)
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Analyze (python)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Test C++ (true, false, false, true)
  • GitHub Check: Test C++ (false, false, false, true)
  • GitHub Check: Test C++ (false, true, true, false)
  • GitHub Check: Test C++ (true, true, true, false)

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 4

🤖 Fix all issues with AI agents
In @deepmd/dpmodel/utils/learning_rate.py:
- Around line 134-143: The return type of the method value is incorrectly
annotated as np.float64 which breaks non-NumPy backends; change its return
annotation to a backend-agnostic type (e.g., typing.Any) so the signature
becomes def value(self, step: int, xp: Any = np) -> Any:, and ensure typing.Any
is imported/available in the module; keep the implementation unchanged since xp
determines the actual array/float type.
- Around line 96-104: The learning rate .value method uses xp.pow which raises
AttributeError for numpy; update the call in LearningRateExponential.value
(method name: value) to use xp.power instead of xp.pow, and make the same
replacement in LearningRateCosine.value (if present) to ensure compatibility
with numpy and Array API backends; keep the rest of the expression and
xp.asarray(...) logic unchanged.

In @source/tests/consistent/test_learning_rate.py:
- Around line 71-72: Rename the mistyped test function test_arary_api_strict to
test_array_api_strict; update the function definition and any references to it
(e.g., in test discovery or elsewhere calling compare_test_with_ref(xp)) so the
test is discovered and executed correctly while keeping the body that calls
compare_test_with_ref(xp) unchanged.
- Line 49: Rename the misnamed test class TestActivationFunctionConsistent to
TestLearningRateConsistent so the class name matches the tests' purpose; update
the class declaration in source/tests/consistent/test_learning_rate.py (the
class symbol TestActivationFunctionConsistent) to TestLearningRateConsistent and
adjust any references to that class within the file or test runner configuration
if present.
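
The `xp.pow` versus `xp.power` incompatibility raised above can be worked around with a small shim. The sketch below is an assumption-laden illustration, not part of the PR: the helper name `xp_pow` is hypothetical. It relies on the Array API standard naming the function `pow`, while older NumPy releases expose only `power`.

```python
import numpy as np


def xp_pow(xp, base, exponent):
    """Call `xp.pow` when the namespace provides it, else fall back to `xp.power`."""
    pow_fn = getattr(xp, "pow", None)
    if pow_fn is None:
        # Namespaces without the standard name (e.g. older NumPy) use `power`.
        pow_fn = xp.power
    return pow_fn(base, exponent)


# Works whether or not the installed NumPy has the `pow` alias.
value = xp_pow(np, np.asarray(0.95), np.asarray(100.0))
```

Obtaining `xp` through a compatibility layer that already provides the standard names is an alternative that avoids a shim of this kind.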
🧹 Nitpick comments (2)
deepmd/dpmodel/utils/learning_rate.py (1)

46-49: Return type annotation np.float64 may be misleading for non-NumPy backends.

Based on learnings, methods returning NumPy scalar types should use corresponding NumPy types. However, with the new xp parameter, the return type can now be a JAX array, PyTorch tensor, etc. Consider updating the type hint to reflect this (e.g., Any or a type variable).

source/tests/consistent/test_learning_rate.py (1)

40-46: Remove unused decay_steps from cosine LR config.

The LearningRateCosine class doesn't use decay_steps (it gets silently absorbed by **kwargs). Consider removing it from this test configuration to avoid confusion.

Proposed fix
         {
             "type": "cosine",
             "start_lr": 1e-3,
             "stop_lr": 1e-8,
-            "decay_steps": 1000,
             "stop_steps": 1000000,
         },
📜 Review details

📥 Commits

Reviewing files that changed from the base of the PR and between 9b1df92 and 836648c.

📒 Files selected for processing (2)
  • deepmd/dpmodel/utils/learning_rate.py
  • source/tests/consistent/test_learning_rate.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2024-10-15T22:22:24.889Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4219
File: deepmd/utils/learning_rate.py:48-53
Timestamp: 2024-10-15T22:22:24.889Z
Learning: Methods in `deepmd/utils/learning_rate.py` that return NumPy scalar types should have return type annotations using the corresponding NumPy types, such as `np.float64`.

Applied to files:

  • deepmd/dpmodel/utils/learning_rate.py
🧬 Code graph analysis (2)
deepmd/dpmodel/utils/learning_rate.py (1)
deepmd/pt/train/training.py (1)
  • step (756-1128)
source/tests/consistent/test_learning_rate.py (1)
source/tests/consistent/common.py (1)
  • parameterized (580-640)
🔇 Additional comments (2)
deepmd/dpmodel/utils/learning_rate.py (1)

89-94: LGTM - .item() correctly converts to Python scalar.

Converting the decay rate to a Python scalar ensures it can be cleanly converted to any backend's array type via xp.asarray().
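
A minimal sketch of what `.item()` does here. The decay-rate formula and all values below are assumptions made for illustration and are not taken from the reviewed lines.

```python
import numpy as np

start_lr, stop_lr = 1e-3, 1e-8           # illustrative values
decay_steps, stop_steps = 5000, 1000000  # illustrative values

# NumPy functions return NumPy scalars (np.float64), even for Python float inputs.
decay_rate_np = np.exp(np.log(stop_lr / start_lr) / (stop_steps / decay_steps))

# .item() unwraps the NumPy scalar into a built-in Python float, which any
# backend's asarray() can accept without knowing about NumPy types.
decay_rate = decay_rate_np.item()
```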

source/tests/consistent/test_learning_rate.py (1)

50-58: Test structure looks correct.

The approach of computing a reference with NumPy and comparing other backends against it is appropriate. The tolerance of 1e-10 is reasonable for floating-point comparisons.

Note: If the np.pow issue in the main implementation is not fixed, the setUp method will fail when computing the reference value since it uses xp=np.
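
The reference-and-compare structure described above can be sketched as follows. This is a simplified stand-in for the PR's test, under stated assumptions: `exp_lr` and its constants are hypothetical, and only NumPy is exercised, whereas the actual tests also cover PyTorch, JAX, and array_api_strict.

```python
import numpy as np

STEP = 500000  # step used for the comparison in the PR walkthrough
ATOL = 1e-10   # tolerance quoted in the review


def exp_lr(step, xp):
    """Hypothetical schedule under test; the constants are illustrative only."""
    step = xp.asarray(step)
    return xp.clip(1e-3 * 0.95 ** (step // 5000), 1e-8, None)


def compare_with_ref(xp, ref):
    """Evaluate the schedule in namespace `xp` and compare with the reference."""
    out = exp_lr(xp.asarray(STEP), xp)
    # Converting to NumPy lets one assertion serve every backend.
    np.testing.assert_allclose(np.asarray(out), ref, atol=ATOL, rtol=0.0)
    return out


ref = exp_lr(STEP, np)           # reference from a plain Python int step
out = compare_with_ref(np, ref)  # same step passed as an array
```

Other backends would be passed to `compare_with_ref` in the same way, typically behind a conditional import so the test is skipped when a backend is not installed.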

@codecov

codecov bot commented Jan 10, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 81.94%. Comparing base (9b1df92) to head (615f692).
⚠️ Report is 2 commits behind head on master.

Additional details and impacted files
@@           Coverage Diff           @@
##           master    #5143   +/-   ##
=======================================
  Coverage   81.93%   81.94%           
=======================================
  Files         712      712           
  Lines       72895    72903    +8     
  Branches     3616     3616           
=======================================
+ Hits        59729    59739   +10     
- Misses      12001    12002    +1     
+ Partials     1165     1162    -3     

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In @deepmd/dpmodel/utils/learning_rate.py:
- Around line 6-22: Remove the duplicate and unused imports at the top of the
module: delete the repeated import of array_api_compat (only keep one) and
remove the unused name override from the typing import list; ensure the
remaining imports still include ModuleType, Any, overload, j_get_type, numpy as
np, and Array so functions that reference them (e.g., any use of j_get_type or
Array in this file) continue to work.
📜 Review details

📥 Commits

Reviewing files that changed from the base of the PR and between 84906e8 and 375a096.

📒 Files selected for processing (1)
  • deepmd/dpmodel/utils/learning_rate.py
🧰 Additional context used
🧠 Learnings (6)
📓 Common learnings
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4204
File: deepmd/dpmodel/fitting/general_fitting.py:426-426
Timestamp: 2024-10-10T22:46:03.419Z
Learning: In `deepmd/dpmodel/fitting/general_fitting.py`, when using the Array API and `array_api_compat`, the `astype` method is not available as an array method. Instead, use `xp.astype()` from the array namespace for type casting.
📚 Learning: 2024-09-24T01:35:48.050Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4160
File: deepmd/dpmodel/utils/region.py:2-2
Timestamp: 2024-09-24T01:35:48.050Z
Learning: Array API doesn't have a type annotation system yet.

Applied to files:

  • deepmd/dpmodel/utils/learning_rate.py
📚 Learning: 2024-09-25T06:36:03.578Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4160
File: deepmd/dpmodel/utils/nlist.py:292-321
Timestamp: 2024-09-25T06:36:03.578Z
Learning: In the Array API, `outer` is only available in `xp.linalg`, not in the main namespace `xp`.

Applied to files:

  • deepmd/dpmodel/utils/learning_rate.py
📚 Learning: 2024-11-23T00:01:06.984Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4406
File: deepmd/dpmodel/array_api.py:51-53
Timestamp: 2024-11-23T00:01:06.984Z
Learning: In `deepmd/dpmodel/array_api.py`, the `__array_api_version__` attribute is guaranteed by the Array API standard to always be present, so error handling for its absence is not required.

Applied to files:

  • deepmd/dpmodel/utils/learning_rate.py
🧬 Code graph analysis (1)
deepmd/dpmodel/utils/learning_rate.py (2)
deepmd/common.py (1)
  • j_get_type (297-315)
deepmd/pt/train/training.py (1)
  • step (756-1128)
🪛 Ruff (0.14.10)
deepmd/dpmodel/utils/learning_rate.py

14-14: Redefinition of unused array_api_compat from line 13

Remove definition: array_api_compat

(F811)

🔇 Additional comments (4)
deepmd/dpmodel/utils/learning_rate.py (4)

54-73: Well-designed backend-agnostic namespace selection.

The array_namespace method with overloads cleanly handles the dual path for integer vs Array inputs, enabling JAX JIT compatibility while maintaining NumPy support for regular training. The use of array_api_compat.is_array_api_obj is the correct approach for detecting array types.


113-118: Good use of .item() for backend compatibility.

Converting the decay rate to a Python scalar ensures compatibility across different array backends that may not handle NumPy scalar types directly.


120-129: Correctly refactored for JIT compatibility.

The switch from conditional if step_lr < self.min_lr to xp.clip avoids the dynamic graph issue in JAX JIT while maintaining the same semantics. The use of xp.pow and xp.asarray follows Array API conventions.


159-169: Correctly adapted for Array API compatibility.

The use of xp.where instead of a conditional statement ensures JIT compatibility. The xp.asarray wrapping on line 166 and 168 correctly handles both integer and array inputs.
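
The `xp.where` pattern for the cosine schedule can be sketched as below. This is an illustration under assumptions, not the reviewed implementation: the function name `cosine_lr`, its defaults, and the choice of `stop_lr` as the floor are hypothetical, and only NumPy is used as the namespace.

```python
import numpy as np


def cosine_lr(step, start_lr=1e-3, stop_lr=1e-8, stop_steps=1000000, xp=np):
    """Hypothetical cosine schedule written against a namespace `xp`."""
    step = xp.asarray(step, dtype=xp.float64)
    lr_min_factor = stop_lr / start_lr
    step_lr = start_lr * (
        lr_min_factor
        + 0.5 * (1 - lr_min_factor) * (1 + xp.cos(xp.pi * (step / stop_steps)))
    )
    # Both candidates are computed; `where` selects per element, so no
    # Python `if` on an array value is needed.
    return xp.where(step >= stop_steps, stop_lr, step_lr)


first = cosine_lr(0)         # start of training
halfway = cosine_lr(500000)  # midpoint of the schedule
after = cosine_lr(2000000)   # past stop_steps, held at the floor
batch = cosine_lr(np.asarray([0, 2000000]))  # array input works the same way
```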

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In @deepmd/dpmodel/utils/learning_rate.py:
- Around line 60-74: The dispatch in array_namespace currently uses
array_api_compat.is_array_api_obj(step) and can misclassify JAX arrays; change
the logic in the array_namespace method to first check if step is a Python
integer using isinstance(step, numbers.Integral) and return numpy (np) for that
case, otherwise call array_api_compat.array_namespace(step) to obtain the proper
namespace (so JAX arrays are handled by array_api_compat rather than falling
back to NumPy); update imports if needed to reference numbers.Integral and
ensure overloads/types (Array) remain consistent.
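
The suggested dispatch can be sketched as below. This is a hedged illustration rather than the PR's method: it uses the Array API standard's `__array_namespace__` hook directly so that it runs without extra dependencies, where the comment above proposes `array_api_compat.array_namespace(step)` for the array branch.

```python
import numbers

import numpy as np


def array_namespace(step):
    """Return NumPy for Python integers, else the array's own namespace."""
    if isinstance(step, numbers.Integral):
        # Plain integer steps (regular training loop) stay on NumPy.
        return np
    # Assumption: the array implements the standard `__array_namespace__`
    # hook; a compatibility layer would also cover arrays that do not.
    return step.__array_namespace__()


xp = array_namespace(500000)  # integer step, so the NumPy module is returned
```

Checking for an integer first means an array input is never silently routed to NumPy, which is the misclassification the comment warns about.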
📜 Review details

📥 Commits

Reviewing files that changed from the base of the PR and between 375a096 and 0efdc3d.

📒 Files selected for processing (1)
  • deepmd/dpmodel/utils/learning_rate.py
🧬 Code graph analysis (1)
deepmd/dpmodel/utils/learning_rate.py (3)
deepmd/common.py (1)
  • j_get_type (297-315)
deepmd/pt/train/training.py (1)
  • step (756-1128)
deepmd/tf/utils/learning_rate.py (1)
  • start_lr (96-98)
🔇 Additional comments (3)
deepmd/dpmodel/utils/learning_rate.py (3)

120-129: The use of xp.clip(step_lr, self.min_lr, None) is correct and fully supported. The Array API Standard explicitly supports max=None to indicate no upper bound, and this pattern is already used elsewhere in the codebase (e.g., in linear_atomic_model.py with xp.clip(nlist_larger, 0, None)). The current implementation is portable across NumPy and JAX and aligns with the standard specification.

Likely an incorrect or invalid review comment.


159-169: The code is correct as-is; Python scalar arguments to xp.where() are acceptable.

The Array API standard explicitly allows Python scalars for where(condition, x1, x2) parameters, requiring only that the condition be an array and at least one of x1/x2 be an array. The current code satisfies both: the condition is an array, and step_lr is an array. The min_lr scalar is valid and does not need wrapping with xp.asarray(). Similarly, array-api-strict explicitly supports Python scalars for x1/x2 (and recently fixed bugs related to this), and array_api_compat delegates to backends and is even more permissive. The proposed refactor is unnecessary.

Likely an incorrect or invalid review comment.


54-58: No action needed. PEP 604 syntax is correct for this codebase.

The project explicitly requires Python >=3.10 in pyproject.toml, and the int | Array union syntax is already standard throughout the codebase, including in deepmd/dpmodel/array_api.py which uses Array = np.ndarray | Any. This is not a critical issue.

Likely an incorrect or invalid review comment.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (2)
source/tests/consistent/test_learning_rate.py (1)

67-68: Unused method compare_numpy_with_ref.

This method is defined but never invoked by any test in the class. Either add a corresponding test method (e.g., test_numpy_consistent_with_ref) or remove this dead code.

deepmd/dpmodel/utils/learning_rate.py (1)

147-167: Consider simplifying the xp.cos argument.

The inner xp.asarray(..., device=...) wrapper around the cos argument (lines 159-162) appears redundant. The expression xp.pi * (xp.astype(step, xp.float64) / self.stop_steps) already produces an array on the same device as step, so the explicit xp.asarray conversion is unnecessary.

Suggested simplification
         step_lr = self.start_lr * (
             self.lr_min_factor
             + 0.5
             * (1 - self.lr_min_factor)
             * (
                 1
-                + xp.cos(
-                    xp.asarray(
-                        xp.pi * (xp.astype(step, xp.float64) / self.stop_steps),
-                        device=array_api_compat.device(step),
-                    )
-                )
+                + xp.cos(xp.pi * (xp.astype(step, xp.float64) / self.stop_steps))
             )
         )

Otherwise, the JIT-friendly xp.where usage is correct and the overall implementation is sound.

📜 Review details

📥 Commits

Reviewing files that changed from the base of the PR and between 0efdc3d and 98a42be.

📒 Files selected for processing (2)
  • deepmd/dpmodel/utils/learning_rate.py
  • source/tests/consistent/test_learning_rate.py
🧬 Code graph analysis (1)
deepmd/dpmodel/utils/learning_rate.py (1)
deepmd/pt/train/training.py (1)
  • step (756-1128)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (41)
  • GitHub Check: Analyze (python)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Analyze (javascript-typescript)
  • GitHub Check: Test Python (9, 3.13)
  • GitHub Check: Test Python (10, 3.10)
  • GitHub Check: Test Python (9, 3.10)
  • GitHub Check: Test Python (11, 3.13)
  • GitHub Check: Test Python (12, 3.13)
  • GitHub Check: Test Python (11, 3.10)
  • GitHub Check: Test Python (12, 3.10)
  • GitHub Check: Test Python (8, 3.10)
  • GitHub Check: Test Python (8, 3.13)
  • GitHub Check: Test Python (7, 3.13)
  • GitHub Check: Test Python (6, 3.13)
  • GitHub Check: Test Python (5, 3.13)
  • GitHub Check: Test Python (7, 3.10)
  • GitHub Check: Test Python (10, 3.13)
  • GitHub Check: Test Python (2, 3.10)
  • GitHub Check: Test Python (3, 3.10)
  • GitHub Check: Test Python (1, 3.13)
  • GitHub Check: Test Python (6, 3.10)
  • GitHub Check: Test Python (4, 3.13)
  • GitHub Check: Test Python (1, 3.10)
  • GitHub Check: Test Python (4, 3.10)
  • GitHub Check: Test Python (3, 3.13)
  • GitHub Check: Test Python (5, 3.10)
  • GitHub Check: Test Python (2, 3.13)
  • GitHub Check: Test C++ (false, false, false, true)
  • GitHub Check: Test C++ (true, true, true, false)
  • GitHub Check: Test C++ (false, true, true, false)
  • GitHub Check: Test C++ (true, false, false, true)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build C++ (cpu, cpu)
🔇 Additional comments (6)
source/tests/consistent/test_learning_rate.py (3)

1-36: LGTM!

Imports and conditional backend loading are well-structured with appropriate guards for optional dependencies.


38-65: LGTM!

Good parameterization covering both "exp" and "cosine" LR types with consistent test configurations. The `setUp` correctly establishes a reference value using the default NumPy backend.


70-85: LGTM!

Test methods are correctly structured with appropriate skip conditions for optional backends. The tolerance of 1e-10 is suitable for numerical consistency checks across backends.
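The reference-plus-tolerance pattern described above might look roughly like the sketch below. It is an assumption-laden stand-in rather than the PR's test file: `toy_lr` is a hypothetical schedule, only NumPy inputs are exercised, and the use of `1e-10` for both `rtol` and `atol` is a guess at how the tolerance is applied. Other backends would be added as further test methods behind skip guards.

```python
import unittest

import numpy as np


def toy_lr(step, xp=np):
    """Hypothetical schedule used only to exercise the test pattern."""
    exponent = xp.asarray(step // 10, dtype=xp.float64)
    return xp.clip(1e-3 * 0.5**exponent, 1e-5, None)


class TestLRConsistency(unittest.TestCase):
    def setUp(self):
        self.step = 25
        # Reference value computed once from a plain Python int step.
        self.ref = float(toy_lr(self.step))

    def test_numpy_array_step(self):
        # Array-valued steps must reproduce the reference within tolerance.
        for dtype in (np.int32, np.int64):
            with self.subTest(dtype=dtype):
                out = toy_lr(np.asarray(self.step, dtype=dtype))
                np.testing.assert_allclose(out, self.ref, rtol=1e-10, atol=1e-10)
```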

deepmd/dpmodel/utils/learning_rate.py (3)

6-22: LGTM!

Imports are appropriate for array API compatibility. The `Array` type from `deepmd.dpmodel.array_api` enables proper type hints for the updated method signatures.


54-58: LGTM!

Clear abstract method signature update with helpful comment explaining JAX JIT constraints.


98-117: LGTM!

The refactoring correctly:

  • Uses `.item()` to convert `decay_rate` to a Python float for storage
  • Detects array API objects and falls back to numpy for plain ints
  • Uses `xp.astype()` per array API conventions (based on learnings)
  • Replaces the conditional `if step_lr < self.min_lr` with JIT-friendly `xp.clip`
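The points above can be sketched as follows. This is an illustration under stated assumptions, not the merged implementation: `pick_namespace` is a hypothetical stand-in for the namespace detection, the staircase formula and default values are assumed, and the float cast uses `xp.asarray(..., dtype=...)` instead of `xp.astype()` so the sketch also runs on older NumPy releases.

```python
import numpy as np


def pick_namespace(step):
    """Hypothetical stand-in for the array-namespace detection described above."""
    if hasattr(step, "__array_namespace__"):
        return step.__array_namespace__()
    # Plain Python ints (and arrays without the protocol) fall back to NumPy.
    return np


def exp_lr(step, start_lr=1e-3, decay_rate=0.5, decay_steps=10, min_lr=1e-5):
    """Assumed staircase exponential decay with a floor at min_lr."""
    xp = pick_namespace(step)
    # Cast the integer decay count to float before exponentiation.
    exponent = xp.asarray(step // decay_steps, dtype=xp.float64)
    step_lr = start_lr * decay_rate**exponent
    # Branch-free replacement for `if step_lr < min_lr: step_lr = min_lr`.
    return xp.clip(step_lr, min_lr, None)


print(exp_lr(10))                     # Python int step
print(exp_lr(np.array([0, 10, 20])))  # array step
```

Because `clip` works element-wise, the same call handles a scalar step and a batch of steps, and it never needs a concrete boolean from a traced value.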

@njzjz njzjz requested a review from iProzd January 10, 2026 20:11
@iProzd iProzd added this pull request to the merge queue Jan 12, 2026
Merged via the queue into deepmodeling:master with commit 854dca8 Jan 12, 2026
70 checks passed