feat(dpmodel): support Array API learning rate #5143
This is useful when the LR is evaluated inside JAX JIT compilation: the step is given as a `jnp.ndarray`, and NumPy should not be used.
Pull request overview
This PR adds Array API support to learning rate schedulers, enabling JAX JIT compilation compatibility by allowing the step parameter to be passed as an array (e.g., jnp.ndarray) instead of only a Python integer.
Changes:
- Modified `BaseLR.value()` and implementations to accept an `xp` parameter for Array API namespace
- Replaced conditional logic with `xp.where()` and `xp.clip()` to avoid dynamic graphs in JAX JIT
- Added comprehensive tests for consistency across NumPy, JAX, PyTorch, and array_api_strict
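The changes listed above can be illustrated with a minimal sketch of an exponential-decay schedule whose `value()` accepts an `xp` namespace. This is not the PR's code: the class name `ExpLRSketch`, its constructor arguments, and the decay-rate formula are assumptions made for illustration, and only NumPy is exercised here.

```python
import numpy as np


class ExpLRSketch:
    """Hypothetical stand-in for an exponential-decay schedule (not the PR's class)."""

    def __init__(self, start_lr, stop_lr, decay_steps, stop_steps):
        self.start_lr = start_lr
        self.min_lr = stop_lr
        self.decay_steps = decay_steps
        # Assumed formula; .item() stores a plain Python float that any backend can wrap.
        self.decay_rate = np.exp(
            np.log(stop_lr / start_lr) / (stop_steps / decay_steps)
        ).item()

    def value(self, step, xp=np):
        # Wrap the step so Python ints and backend arrays share one code path.
        step = xp.asarray(step, dtype=xp.float64)
        # floor(step / decay_steps) gives the staircase exponent without a Python branch.
        exponent = xp.floor(step / self.decay_steps)
        step_lr = self.start_lr * xp.asarray(self.decay_rate, dtype=xp.float64) ** exponent
        # xp.clip stands in for `if step_lr < self.min_lr`, which cannot branch on a traced value.
        return xp.clip(step_lr, self.min_lr, None)


sched = ExpLRSketch(start_lr=1e-3, stop_lr=1e-8, decay_steps=1000, stop_steps=1000000)
lr_from_int = sched.value(5000)
lr_from_array = sched.value(np.asarray(5000))
```

Passing another Array API namespace as `xp` (for example `jax.numpy`) is the intended use, but that path is not verified by this sketch.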
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| deepmd/dpmodel/utils/learning_rate.py | Updated base class and subclasses to support Array API operations; replaced Python conditionals with array operations for JAX JIT compatibility |
| source/tests/consistent/test_learning_rate.py | Added consistency tests for learning rate schedulers across different array libraries |
📝 Walkthrough
Refactor learning-rate implementations to be array-API compatible.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 4
🤖 Fix all issues with AI agents
In @deepmd/dpmodel/utils/learning_rate.py:
- Around line 134-143: The return type of the method value is incorrectly
annotated as np.float64 which breaks non-NumPy backends; change its return
annotation to a backend-agnostic type (e.g., typing.Any) so the signature
becomes def value(self, step: int, xp: Any = np) -> Any:, and ensure typing.Any
is imported/available in the module; keep the implementation unchanged since xp
determines the actual array/float type.
- Around line 96-104: The learning rate .value method uses xp.pow which raises
AttributeError for numpy; update the call in LearningRateExponential.value
(method name: value) to use xp.power instead of xp.pow, and make the same
replacement in LearningRateCosine.value (if present) to ensure compatibility
with numpy and Array API backends; keep the rest of the expression and
xp.asarray(...) logic unchanged.
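The `xp.pow` versus `xp.power` point can be sketched as a small helper that picks whichever name the namespace provides. The helper `portable_pow` is hypothetical and not part of the PR. To my knowledge the Array API standard names the function `pow`, NumPy has always offered `power`, and NumPy added `pow` as an alias in 2.0, so treat those version details as an assumption to verify.

```python
import numpy as np


def portable_pow(xp, base, exponent):
    """Raise base to exponent with whichever power function the namespace offers."""
    # Prefer the Array API name `pow`; fall back to NumPy's older `power`.
    pow_fn = getattr(xp, "pow", None) or getattr(xp, "power")
    return pow_fn(xp.asarray(base), xp.asarray(exponent))


result = portable_pow(np, 0.5, 3.0)
```

Using the `**` operator on arrays is another way to sidestep the naming difference entirely.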
In @source/tests/consistent/test_learning_rate.py:
- Around line 71-72: Rename the mistyped test function test_arary_api_strict to
test_array_api_strict; update the function definition and any references to it
(e.g., in test discovery or elsewhere calling compare_test_with_ref(xp)) so the
test is discovered and executed correctly while keeping the body that calls
compare_test_with_ref(xp) unchanged.
- Line 49: Rename the misnamed test class TestActivationFunctionConsistent to
TestLearningRateConsistent so the class name matches the tests' purpose; update
the class declaration in source/tests/consistent/test_learning_rate.py (the
class symbol TestActivationFunctionConsistent) to TestLearningRateConsistent and
adjust any references to that class within the file or test runner configuration
if present.
🧹 Nitpick comments (2)
deepmd/dpmodel/utils/learning_rate.py (1)
46-49: Return type annotation `np.float64` may be misleading for non-NumPy backends.
Based on learnings, methods returning NumPy scalar types should use corresponding NumPy types. However, with the new `xp` parameter, the return type can now be a JAX array, PyTorch tensor, etc. Consider updating the type hint to reflect this (e.g., `Any` or a type variable).
source/tests/consistent/test_learning_rate.py (1)
40-46: Remove unused `decay_steps` from cosine LR config.
The `LearningRateCosine` class doesn't use `decay_steps` (it gets silently absorbed by `**kwargs`). Consider removing it from this test configuration to avoid confusion.
Proposed fix:
     {
         "type": "cosine",
         "start_lr": 1e-3,
         "stop_lr": 1e-8,
    -    "decay_steps": 1000,
         "stop_steps": 1000000,
     },
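To see why an unused key goes unnoticed, here is a minimal sketch of a constructor that accepts `**kwargs`. The class `CosineLRSketch` and its `ignored` attribute are hypothetical and exist only for this illustration; the real `LearningRateCosine` signature is not shown in this conversation.

```python
class CosineLRSketch:
    """Hypothetical schedule whose constructor swallows unknown options."""

    def __init__(self, start_lr, stop_lr, stop_steps, **kwargs):
        self.start_lr = start_lr
        self.stop_lr = stop_lr
        self.stop_steps = stop_steps
        # Unused options land in kwargs; a real class would drop them without warning.
        self.ignored = sorted(kwargs)


config = {"start_lr": 1e-3, "stop_lr": 1e-8, "decay_steps": 1000, "stop_steps": 1000000}
lr = CosineLRSketch(**config)
```

Here `lr.ignored` records `decay_steps`, the key the nitpick proposes to delete from the test configuration.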
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
deepmd/dpmodel/utils/learning_rate.py
source/tests/consistent/test_learning_rate.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2024-10-15T22:22:24.889Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4219
File: deepmd/utils/learning_rate.py:48-53
Timestamp: 2024-10-15T22:22:24.889Z
Learning: Methods in `deepmd/utils/learning_rate.py` that return NumPy scalar types should have return type annotations using the corresponding NumPy types, such as `np.float64`.
Applied to files:
deepmd/dpmodel/utils/learning_rate.py
🧬 Code graph analysis (2)
deepmd/dpmodel/utils/learning_rate.py (1)
deepmd/pt/train/training.py (1)
step(756-1128)
source/tests/consistent/test_learning_rate.py (1)
source/tests/consistent/common.py (1)
parameterized(580-640)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (42)
- GitHub Check: CodeQL analysis (python)
- GitHub Check: Agent
- GitHub Check: Test Python (3, 3.10)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Analyze (python)
- GitHub Check: Test Python (8, 3.10)
- GitHub Check: Test Python (5, 3.10)
- GitHub Check: Test Python (10, 3.13)
- GitHub Check: Test Python (11, 3.13)
- GitHub Check: Test Python (4, 3.13)
- GitHub Check: Test Python (8, 3.13)
- GitHub Check: Test Python (7, 3.10)
- GitHub Check: Test Python (12, 3.10)
- GitHub Check: Test Python (9, 3.10)
- GitHub Check: Test Python (12, 3.13)
- GitHub Check: Test Python (4, 3.10)
- GitHub Check: Test Python (6, 3.13)
- GitHub Check: Test Python (5, 3.13)
- GitHub Check: Test Python (6, 3.10)
- GitHub Check: Test Python (7, 3.13)
- GitHub Check: Test Python (11, 3.10)
- GitHub Check: Test Python (2, 3.13)
- GitHub Check: Test Python (10, 3.10)
- GitHub Check: Test C++ (true, false, false, true)
- GitHub Check: Test Python (1, 3.10)
- GitHub Check: Test Python (9, 3.13)
- GitHub Check: Test Python (3, 3.13)
- GitHub Check: Test C++ (true, true, true, false)
- GitHub Check: Test Python (1, 3.13)
- GitHub Check: Test Python (2, 3.10)
- GitHub Check: Test C++ (false, false, false, true)
- GitHub Check: Test C++ (false, true, true, false)
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Build C++ (cuda120, cuda)
🔇 Additional comments (2)
deepmd/dpmodel/utils/learning_rate.py (1)
89-94: LGTM - `.item()` correctly converts to Python scalar.
Converting the decay rate to a Python scalar ensures it can be cleanly converted to any backend's array type via `xp.asarray()`.
source/tests/consistent/test_learning_rate.py (1)
50-58: Test structure looks correct.
The approach of computing a reference with NumPy and comparing other backends against it is appropriate. The tolerance of `1e-10` is reasonable for floating-point comparisons.
Note: If the `np.pow` issue in the main implementation is not fixed, the `setUp` method will fail when computing the reference value since it uses `xp=np`.
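The reference-then-compare pattern described here might look like the following sketch. The schedule `exp_lr`, its default parameters, and the helper `compare_with_ref` are hypothetical; only the idea (NumPy reference, per-backend comparison, tolerance `1e-10`) comes from the review. NumPy is used as the backend under test so the example runs without optional dependencies.

```python
import numpy as np


def exp_lr(step, xp=np, start_lr=1e-3, decay_rate=0.95, decay_steps=1000, min_lr=1e-8):
    """Hypothetical schedule used only to demonstrate the comparison pattern."""
    step = xp.asarray(step, dtype=xp.float64)
    step_lr = start_lr * decay_rate ** xp.floor(step / decay_steps)
    return xp.clip(step_lr, min_lr, None)


def compare_with_ref(xp, step=5000, atol=1e-10):
    # Reference value: NumPy backend with a plain Python int step.
    ref = exp_lr(step)
    # Candidate value: the step is wrapped as an array of the backend under test.
    got = exp_lr(xp.asarray(step), xp=xp)
    np.testing.assert_allclose(np.asarray(got), np.asarray(ref), atol=atol)
    return float(np.asarray(got))


checked = compare_with_ref(np)
```

In a real test suite each optional backend would get its own test method guarded by a skip condition, calling the comparison helper with that backend's namespace.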
Codecov Report
✅ All modified and coverable lines are covered by tests.
Coverage diff:
| Metric | master | #5143 | +/- |
|---|---|---|---|
| Coverage | 81.93% | 81.94% | |
| Files | 712 | 712 | |
| Lines | 72895 | 72903 | +8 |
| Branches | 3616 | 3616 | |
| Hits | 59729 | 59739 | +10 |
| Misses | 12001 | 12002 | +1 |
| Partials | 1165 | 1162 | -3 |
…epmd-kit into array-api-learning-rate
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In @deepmd/dpmodel/utils/learning_rate.py:
- Around line 6-22: Remove the duplicate and unused imports at the top of the
module: delete the repeated import of array_api_compat (only keep one) and
remove the unused name override from the typing import list; ensure the
remaining imports still include ModuleType, Any, overload, j_get_type, numpy as
np, and Array so functions that reference them (e.g., any use of j_get_type or
Array in this file) continue to work.
📒 Files selected for processing (1)
deepmd/dpmodel/utils/learning_rate.py
🧰 Additional context used
🧠 Learnings (6)
📓 Common learnings
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4204
File: deepmd/dpmodel/fitting/general_fitting.py:426-426
Timestamp: 2024-10-10T22:46:03.419Z
Learning: In `deepmd/dpmodel/fitting/general_fitting.py`, when using the Array API and `array_api_compat`, the `astype` method is not available as an array method. Instead, use `xp.astype()` from the array namespace for type casting.
📚 Learning: 2024-10-15T22:22:24.889Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4219
File: deepmd/utils/learning_rate.py:48-53
Timestamp: 2024-10-15T22:22:24.889Z
Learning: Methods in `deepmd/utils/learning_rate.py` that return NumPy scalar types should have return type annotations using the corresponding NumPy types, such as `np.float64`.
Applied to files:
deepmd/dpmodel/utils/learning_rate.py
📚 Learning: 2024-10-10T22:46:03.419Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4204
File: deepmd/dpmodel/fitting/general_fitting.py:426-426
Timestamp: 2024-10-10T22:46:03.419Z
Learning: In `deepmd/dpmodel/fitting/general_fitting.py`, when using the Array API and `array_api_compat`, the `astype` method is not available as an array method. Instead, use `xp.astype()` from the array namespace for type casting.
Applied to files:
deepmd/dpmodel/utils/learning_rate.py
📚 Learning: 2024-09-24T01:35:48.050Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4160
File: deepmd/dpmodel/utils/region.py:2-2
Timestamp: 2024-09-24T01:35:48.050Z
Learning: Array API doesn't have a type annotation system yet.
Applied to files:
deepmd/dpmodel/utils/learning_rate.py
📚 Learning: 2024-09-25T06:36:03.578Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4160
File: deepmd/dpmodel/utils/nlist.py:292-321
Timestamp: 2024-09-25T06:36:03.578Z
Learning: In the Array API, `outer` is only available in `xp.linalg`, not in the main namespace `xp`.
Applied to files:
deepmd/dpmodel/utils/learning_rate.py
📚 Learning: 2024-11-23T00:01:06.984Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4406
File: deepmd/dpmodel/array_api.py:51-53
Timestamp: 2024-11-23T00:01:06.984Z
Learning: In `deepmd/dpmodel/array_api.py`, the `__array_api_version__` attribute is guaranteed by the Array API standard to always be present, so error handling for its absence is not required.
Applied to files:
deepmd/dpmodel/utils/learning_rate.py
🧬 Code graph analysis (1)
deepmd/dpmodel/utils/learning_rate.py (2)
deepmd/common.py (1)
j_get_type(297-315)
deepmd/pt/train/training.py (1)
step(756-1128)
🪛 Ruff (0.14.10)
deepmd/dpmodel/utils/learning_rate.py
14-14: Redefinition of unused array_api_compat from line 13
Remove definition: array_api_compat
(F811)
🔇 Additional comments (4)
deepmd/dpmodel/utils/learning_rate.py (4)
54-73: Well-designed backend-agnostic namespace selection.
The `array_namespace` method with overloads cleanly handles the dual path for integer vs Array inputs, enabling JAX JIT compatibility while maintaining NumPy support for regular training. The use of `array_api_compat.is_array_api_obj` is the correct approach for detecting array types.
113-118: Good use of `.item()` for backend compatibility.
Converting the decay rate to a Python scalar ensures compatibility across different array backends that may not handle NumPy scalar types directly.
120-129: Correctly refactored for JIT compatibility.
The switch from conditional `if step_lr < self.min_lr` to `xp.clip` avoids the dynamic graph issue in JAX JIT while maintaining the same semantics. The use of `xp.pow` and `xp.asarray` follows Array API conventions.
159-169: Correctly adapted for Array API compatibility.
The use of `xp.where` instead of a conditional statement ensures JIT compatibility. The `xp.asarray` wrapping on lines 166 and 168 correctly handles both integer and array inputs.
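A minimal sketch of a cosine schedule that selects the floor value with `xp.where` rather than a Python `if`. The function `cosine_lr` is hypothetical, and deriving `lr_min_factor` as `stop_lr / start_lr` is an assumption; the cosine expression follows the one quoted elsewhere in this conversation.

```python
import numpy as np


def cosine_lr(step, xp=np, start_lr=1e-3, stop_lr=1e-8, stop_steps=1000000):
    """Cosine annealing sketch; parameter names mirror the test config."""
    # Assumption: the minimum factor is the ratio of the stop and start rates.
    lr_min_factor = stop_lr / start_lr
    step = xp.asarray(step, dtype=xp.float64)
    step_lr = start_lr * (
        lr_min_factor
        + 0.5 * (1 - lr_min_factor) * (1 + xp.cos(xp.pi * (step / stop_steps)))
    )
    # Both branches are computed and selected elementwise, so no Python bool is needed.
    return xp.where(step >= stop_steps, xp.asarray(stop_lr, dtype=xp.float64), step_lr)


lr_start = cosine_lr(0)
lr_after_end = cosine_lr(2000000)
```

Because `xp.where` evaluates both operands, the cosine term is still computed past `stop_steps`; that is harmless here since the expression is finite for any step.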
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In @deepmd/dpmodel/utils/learning_rate.py:
- Around line 60-74: The dispatch in array_namespace currently uses
array_api_compat.is_array_api_obj(step) and can misclassify JAX arrays; change
the logic in the array_namespace method to first check if step is a Python
integer using isinstance(step, numbers.Integral) and return numpy (np) for that
case, otherwise call array_api_compat.array_namespace(step) to obtain the proper
namespace (so JAX arrays are handled by array_api_compat rather than falling
back to NumPy); update imports if needed to reference numbers.Integral and
ensure overloads/types (Array) remain consistent.
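The dispatch order suggested here (integer check first, namespace lookup second) can be sketched as follows. To keep the example dependency-free it uses the standard `__array_namespace__` protocol as a stand-in for `array_api_compat.array_namespace`, which is the call the comment actually proposes. The function name `namespace_for_step` and the final NumPy fallback are assumptions.

```python
import numbers

import numpy as np


def namespace_for_step(step):
    """Pick an array namespace for a step that is either a Python int or an array."""
    # Plain integers (including NumPy integer scalars) are handled with NumPy.
    if isinstance(step, numbers.Integral):
        return np
    # Arrays that implement the Array API protocol report their own namespace.
    get_namespace = getattr(step, "__array_namespace__", None)
    if get_namespace is not None:
        return get_namespace()
    # Assumption: anything else is treated as NumPy-compatible.
    return np


ns_int = namespace_for_step(7)
ns_array = namespace_for_step(np.asarray(7))
```

Checking for integers first means the array branch is only reached for genuine array inputs, which is the behaviour the comment asks for.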
📒 Files selected for processing (1)
deepmd/dpmodel/utils/learning_rate.py
🧰 Additional context used
🧬 Code graph analysis (1)
deepmd/dpmodel/utils/learning_rate.py (3)
deepmd/common.py (1)
j_get_type(297-315)
deepmd/pt/train/training.py (1)
step(756-1128)
deepmd/tf/utils/learning_rate.py (1)
start_lr(96-98)
🔇 Additional comments (3)
deepmd/dpmodel/utils/learning_rate.py (3)
120-129: The use of `xp.clip(step_lr, self.min_lr, None)` is correct and fully supported.
The Array API Standard explicitly supports `max=None` to indicate no upper bound, and this pattern is already used elsewhere in the codebase (e.g., in `linear_atomic_model.py` with `xp.clip(nlist_larger, 0, None)`). The current implementation is portable across NumPy and JAX and aligns with the standard specification.
Likely an incorrect or invalid review comment.
159-169: The code is correct as-is; Python scalar arguments to `xp.where()` are acceptable.
The Array API standard explicitly allows Python scalars for `where(condition, x1, x2)` parameters, requiring only that the condition be an array and at least one of x1/x2 be an array. The current code satisfies both: the condition is an array, and `step_lr` is an array. The `min_lr` scalar is valid and does not need wrapping with `xp.asarray()`. Similarly, `array-api-strict` explicitly supports Python scalars for x1/x2 (and recently fixed bugs related to this), and `array_api_compat` delegates to backends and is even more permissive. The proposed refactor is unnecessary.
Likely an incorrect or invalid review comment.
54-58: No action needed. PEP 604 syntax is correct for this codebase.
The project explicitly requires Python >=3.10 in `pyproject.toml`, and the `int | Array` union syntax is already standard throughout the codebase, including in `deepmd/dpmodel/array_api.py` which uses `Array = np.ndarray | Any`. This is not a critical issue.
Likely an incorrect or invalid review comment.
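The two patterns defended in these comments, a one-sided `clip` and a `where` with a Python scalar operand, can be checked with a short NumPy-only sketch. The array values are made up for illustration, and behaviour under other backends is not verified here.

```python
import numpy as np

lrs = np.asarray([1e-3, 1e-6, 1e-9])
min_lr = 1e-8

# Passing None as the upper bound means only the lower bound is enforced.
clipped = np.clip(lrs, min_lr, None)

# A Python scalar is accepted as one of the two value operands of where().
selected = np.where(lrs < min_lr, min_lr, lrs)
```

For a lower bound, both forms give the same result, which is why the conditional in the schedule can be replaced by either one.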
…epmd-kit into array-api-learning-rate
Actionable comments posted: 0
🧹 Nitpick comments (2)
source/tests/consistent/test_learning_rate.py (1)
67-68: Unused method `compare_numpy_with_ref`.
This method is defined but never invoked by any test in the class. Either add a corresponding test method (e.g., `test_numpy_consistent_with_ref`) or remove this dead code.
deepmd/dpmodel/utils/learning_rate.py (1)
147-167: Consider simplifying the `xp.cos` argument.
The inner `xp.asarray(..., device=...)` wrapper around the cos argument (lines 159-162) appears redundant. The expression `xp.pi * (xp.astype(step, xp.float64) / self.stop_steps)` already produces an array on the same device as `step`, so the explicit `xp.asarray` conversion is unnecessary.
Suggested simplification:
     step_lr = self.start_lr * (
         self.lr_min_factor
         + 0.5
         * (1 - self.lr_min_factor)
         * (
             1
    -        + xp.cos(
    -            xp.asarray(
    -                xp.pi * (xp.astype(step, xp.float64) / self.stop_steps),
    -                device=array_api_compat.device(step),
    -            )
    -        )
    +        + xp.cos(xp.pi * (xp.astype(step, xp.float64) / self.stop_steps))
         )
     )
Otherwise, the JIT-friendly `xp.where` usage is correct and the overall implementation is sound.
📒 Files selected for processing (2)
deepmd/dpmodel/utils/learning_rate.py
source/tests/consistent/test_learning_rate.py
🧰 Additional context used
🧬 Code graph analysis (1)
deepmd/dpmodel/utils/learning_rate.py (1)
deepmd/pt/train/training.py (1)
step(756-1128)
🔇 Additional comments (6)
source/tests/consistent/test_learning_rate.py (3)
1-36: LGTM!
Imports and conditional backend loading are well-structured with appropriate guards for optional dependencies.
38-65: LGTM!
Good parameterization covering both "exp" and "cosine" LR types with consistent test configurations. The setUp correctly establishes a reference value using the default numpy backend.
70-85: LGTM!
Test methods are correctly structured with appropriate skip conditions for optional backends. The tolerance of `1e-10` is suitable for numerical consistency checks across backends.
deepmd/dpmodel/utils/learning_rate.py (3)
6-22: LGTM!
Imports are appropriate for array API compatibility. The `Array` type from `deepmd.dpmodel.array_api` enables proper type hints for the updated method signatures.
54-58: LGTM!
Clear abstract method signature update with helpful comment explaining JAX JIT constraints.
98-117: LGTM!
The refactoring correctly:
- Uses `.item()` to convert `decay_rate` to a Python float for storage
- Detects array API objects and falls back to numpy for plain ints
- Uses `xp.astype()` per array API conventions (based on learnings)
- Replaces the conditional `if step_lr < self.min_lr` with JIT-friendly `xp.clip`