fix(jax): fix support for default fparam #5123
Conversation
Pull request overview
This PR fixes JAX support for models with default frame parameters by adding the get_default_fparam() method throughout the JAX/dpmodel stack and updating serialization to save/load these values.
- Adds a `get_default_fparam()` method to return default frame parameters alongside the existing `has_default_fparam()` check (a minimal sketch follows this list)
- Updates serialization to persist default_fparam values in HLO and SavedModel formats
- Implements default fparam handling in JAX inference to satisfy XLA's static shape requirements
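For illustration only, a rough sketch of what the paired accessors could look like at the fitting level; the class name and attribute handling are assumptions, not the actual `general_fitting.py` code:

```python
from typing import Optional


class FittingSketch:
    """Stand-in for a fitting net that may carry default frame parameters."""

    def __init__(self, default_fparam: Optional[list[float]] = None) -> None:
        # None means no default was configured for this model.
        self.default_fparam = default_fparam

    def has_default_fparam(self) -> bool:
        # Existing check: defaults exist only when the attribute is set.
        return self.default_fparam is not None

    def get_default_fparam(self) -> Optional[list[float]]:
        # New accessor added by this PR: expose the values, or None when absent.
        return self.default_fparam
```

Callers can then check `has_default_fparam()` before requesting the values, mirroring the flow used during inference.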
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| deepmd/jax/utils/serialization.py | Adds has_default_fparam and default_fparam to serialized model constants |
| deepmd/jax/model/hlo.py | Adds constructor parameters and getter methods for default fparam support |
| deepmd/jax/jax2tf/tfmodel.py | Adds backward-compatible loading and getter methods for default fparam |
| deepmd/jax/jax2tf/serialization.py | Adds TensorFlow functions to expose default fparam in SavedModel |
| deepmd/jax/infer/deep_eval.py | Applies default fparam when no explicit fparam provided and adds has_default_fparam method |
| deepmd/dpmodel/model/make_model.py | Delegates get_default_fparam to atomic model |
| deepmd/dpmodel/fitting/general_fitting.py | Returns the default_fparam attribute |
| deepmd/dpmodel/atomic_model/dp_atomic_model.py | Delegates get_default_fparam to fitting |
| deepmd/dpmodel/atomic_model/base_atomic_model.py | Provides base implementation returning empty list |
Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough

Adds get_default_fparam() getters and has_default_fparam() checkers across the model, fitting, JAX inference, and TF wrapper/serialization layers; DeepEval now uses a default fparam when fparam is None by tiling it across frames for static-shape requirements.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Client
    participant DeepEval
    participant AtomicModel
    participant TFModel
    rect rgba(220,240,220,0.25)
        note left of DeepEval: Inference entrypoint (new default-fparam flow)
    end
    Client->>DeepEval: infer(coords, fparam=None)
    alt fparam is None and DeepEval.has_default_fparam()
        DeepEval->>AtomicModel: has_default_fparam()
        AtomicModel-->>DeepEval: True
        DeepEval->>AtomicModel: get_default_fparam()
        AtomicModel-->>DeepEval: default_fparam (list)
        rect rgba(200,220,255,0.18)
            note over DeepEval: Tile default_fparam across frames<br/>(to meet static-shape requirement)
        end
        DeepEval->>TFModel: call(model_inputs with tiled fparam)
    else fparam provided or no default
        DeepEval->>TFModel: call(model_inputs with given/None fparam)
    end
    TFModel-->>Client: inference result
```
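The tiling step in the diagram could be written roughly as follows on the DeepEval side; the function and variable names are illustrative, not the actual `deepmd/jax/infer/deep_eval.py` code:

```python
import numpy as np


def resolve_fparam(fparam, nframes, model):
    """Fall back to the model's default fparam, tiled over all frames."""
    # XLA traces the exported call with a fixed fparam shape, so a full
    # (nframes, numb_fparam) array must be supplied even when the caller
    # passes fparam=None.
    if fparam is None and model.has_default_fparam():
        default = np.asarray(model.get_default_fparam(), dtype=np.float64)
        # Repeat the single default vector for every frame in the batch.
        return np.tile(default.reshape(1, -1), (nframes, 1))
    return fparam
```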
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Possibly related PRs
Suggested reviewers
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✨ Finishing touches
Actionable comments posted: 5
Fix all issues with AI Agents 🤖
In @deepmd/dpmodel/atomic_model/base_atomic_model.py:
- Around line 98-100: get_default_fparam currently declares the wrong return
type and returns an empty list even when has_default_fparam() is False; change
its signature to return list[float] | None and update the implementation to
return None when no default frame parameters exist (consistent with
has_default_fparam()), otherwise return a list[float] of defaults; update
references if any to expect list[float] | None.
In @deepmd/dpmodel/atomic_model/dp_atomic_model.py:
- Around line 243-245: The return type annotation on get_default_fparam is
incorrect; update the signature of get_default_fparam (method name:
get_default_fparam) to return list[float] | None instead of list[int] | None to
match the actual return value from self.fitting.get_default_fparam(), which is
defined as list[float] | None in general_fitting.py and used with tf.double
elsewhere.
In @deepmd/dpmodel/fitting/general_fitting.py:
- Around line 307-309: The return type annotation for get_default_fparam is
incorrect; change it from list[int] | None to list[float] | None to match the
attribute self.default_fparam (which is defined as list[float] | None) and
reflect that frame parameters are floats; update the function signature of
get_default_fparam accordingly so the annotation and returned value are
consistent with the default_fparam attribute.
In @deepmd/dpmodel/model/make_model.py:
- Around line 570-572: Update the return type annotation of get_default_fparam
to reflect floating-point frame parameters: change its signature from list[int]
| None to list[float] | None and ensure it still returns the result of
self.atomic_model.get_default_fparam() (which also returns list[float] | None).
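Taken together, these comments ask for one consistent `list[float] | None` annotation along the delegation chain. A hedged sketch of that chain, with simplified stand-in class names rather than the real ones:

```python
from typing import Optional


class DPAtomicModelSketch:
    def __init__(self, fitting) -> None:
        self.fitting = fitting

    def get_default_fparam(self) -> Optional[list[float]]:
        # Delegates to the fitting; frame parameters are floats, hence
        # list[float] | None rather than list[int] | None.
        return self.fitting.get_default_fparam()


class ModelSketch:
    def __init__(self, atomic_model) -> None:
        self.atomic_model = atomic_model

    def get_default_fparam(self) -> Optional[list[float]]:
        # The top-level model keeps the same annotation as the layers below.
        return self.atomic_model.get_default_fparam()
```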
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (9)
- deepmd/dpmodel/atomic_model/base_atomic_model.py
- deepmd/dpmodel/atomic_model/dp_atomic_model.py
- deepmd/dpmodel/fitting/general_fitting.py
- deepmd/dpmodel/model/make_model.py
- deepmd/jax/infer/deep_eval.py
- deepmd/jax/jax2tf/serialization.py
- deepmd/jax/jax2tf/tfmodel.py
- deepmd/jax/model/hlo.py
- deepmd/jax/utils/serialization.py
🧰 Additional context used
🧬 Code graph analysis (6)
deepmd/jax/infer/deep_eval.py (4)
- deepmd/dpmodel/atomic_model/base_atomic_model.py (2): has_default_fparam (94-96), get_default_fparam (98-100)
- deepmd/dpmodel/atomic_model/dp_atomic_model.py (2): has_default_fparam (239-241), get_default_fparam (243-245)
- deepmd/dpmodel/model/make_model.py (2): has_default_fparam (566-568), get_default_fparam (570-572)
- deepmd/dpmodel/infer/deep_eval.py (1): has_default_fparam (128-130)

deepmd/dpmodel/model/make_model.py (8)
- deepmd/dpmodel/atomic_model/base_atomic_model.py (1): get_default_fparam (98-100)
- deepmd/dpmodel/atomic_model/dp_atomic_model.py (1): get_default_fparam (243-245)
- deepmd/dpmodel/fitting/general_fitting.py (1): get_default_fparam (307-309)
- deepmd/jax/jax2tf/serialization.py (1): get_default_fparam (330-335)
- deepmd/jax/jax2tf/tfmodel.py (1): get_default_fparam (348-350)
- deepmd/jax/model/hlo.py (1): get_default_fparam (340-342)
- deepmd/pt/model/model/make_model.py (1): get_default_fparam (536-537)
- deepmd/pt/model/task/fitting.py (1): get_default_fparam (644-645)

deepmd/dpmodel/atomic_model/base_atomic_model.py (7)
- deepmd/dpmodel/atomic_model/dp_atomic_model.py (1): get_default_fparam (243-245)
- deepmd/dpmodel/fitting/general_fitting.py (1): get_default_fparam (307-309)
- deepmd/dpmodel/model/make_model.py (1): get_default_fparam (570-572)
- deepmd/jax/jax2tf/serialization.py (1): get_default_fparam (330-335)
- deepmd/jax/jax2tf/tfmodel.py (1): get_default_fparam (348-350)
- deepmd/jax/model/hlo.py (1): get_default_fparam (340-342)
- deepmd/pt/model/atomic_model/dp_atomic_model.py (1): get_default_fparam (384-385)

deepmd/jax/utils/serialization.py (9)
- deepmd/dpmodel/atomic_model/base_atomic_model.py (2): has_default_fparam (94-96), get_default_fparam (98-100)
- deepmd/dpmodel/atomic_model/dp_atomic_model.py (2): has_default_fparam (239-241), get_default_fparam (243-245)
- deepmd/dpmodel/fitting/general_fitting.py (2): has_default_fparam (303-305), get_default_fparam (307-309)
- deepmd/dpmodel/model/make_model.py (2): has_default_fparam (566-568), get_default_fparam (570-572)
- deepmd/jax/infer/deep_eval.py (1): has_default_fparam (445-447)
- deepmd/jax/jax2tf/serialization.py (2): has_default_fparam (324-325), get_default_fparam (330-335)
- deepmd/jax/jax2tf/tfmodel.py (2): has_default_fparam (344-346), get_default_fparam (348-350)
- deepmd/jax/model/hlo.py (2): has_default_fparam (336-338), get_default_fparam (340-342)
- deepmd/dpmodel/infer/deep_eval.py (1): has_default_fparam (128-130)

deepmd/dpmodel/atomic_model/dp_atomic_model.py (9)
- deepmd/dpmodel/atomic_model/base_atomic_model.py (1): get_default_fparam (98-100)
- deepmd/dpmodel/fitting/general_fitting.py (1): get_default_fparam (307-309)
- deepmd/dpmodel/model/make_model.py (1): get_default_fparam (570-572)
- deepmd/jax/jax2tf/serialization.py (1): get_default_fparam (330-335)
- deepmd/jax/jax2tf/tfmodel.py (1): get_default_fparam (348-350)
- deepmd/jax/model/hlo.py (1): get_default_fparam (340-342)
- deepmd/pt/model/atomic_model/dp_atomic_model.py (1): get_default_fparam (384-385)
- deepmd/pt/model/model/make_model.py (1): get_default_fparam (536-537)
- deepmd/pt/model/task/fitting.py (1): get_default_fparam (644-645)

deepmd/jax/model/hlo.py (7)
- deepmd/dpmodel/atomic_model/base_atomic_model.py (2): has_default_fparam (94-96), get_default_fparam (98-100)
- deepmd/dpmodel/atomic_model/dp_atomic_model.py (2): has_default_fparam (239-241), get_default_fparam (243-245)
- deepmd/dpmodel/fitting/general_fitting.py (2): has_default_fparam (303-305), get_default_fparam (307-309)
- deepmd/dpmodel/model/make_model.py (2): has_default_fparam (566-568), get_default_fparam (570-572)
- deepmd/jax/infer/deep_eval.py (1): has_default_fparam (445-447)
- deepmd/jax/jax2tf/serialization.py (2): has_default_fparam (324-325), get_default_fparam (330-335)
- deepmd/jax/jax2tf/tfmodel.py (2): has_default_fparam (344-346), get_default_fparam (348-350)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (30)
- GitHub Check: CodeQL analysis (python)
- GitHub Check: Agent
- GitHub Check: Analyze (python)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Test Python (5, 3.13)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Test Python (2, 3.13)
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Test C++ (false, false, false, true)
- GitHub Check: Test C++ (true, false, false, true)
- GitHub Check: Test Python (1, 3.10)
- GitHub Check: Test C++ (true, true, true, false)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Test C++ (false, true, true, false)
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Test Python (5, 3.10)
- GitHub Check: Test Python (1, 3.13)
- GitHub Check: Test Python (2, 3.10)
- GitHub Check: Test Python (4, 3.13)
- GitHub Check: Test Python (4, 3.10)
- GitHub Check: Test Python (3, 3.10)
- GitHub Check: Test Python (6, 3.10)
- GitHub Check: Test Python (6, 3.13)
- GitHub Check: Test Python (3, 3.13)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
🔇 Additional comments (7)
deepmd/jax/utils/serialization.py (1)
136-137: LGTM! Serialization of default fparam metadata. The additions correctly serialize the default frame parameter metadata alongside the other model constants, enabling consumers of the serialized model to query this information.
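As a hedged illustration of what "alongside other model constants" could mean, the serialized constants might gain two extra entries; all keys and values below are placeholders rather than the exact contents of serialization.py:

```python
# Hypothetical constants mapping stored next to the exported model; only the
# two default-fparam entries come from this PR, the other keys are placeholders.
model_constants = {
    "type_map": ["O", "H"],      # placeholder
    "dim_fparam": 1,             # placeholder
    "has_default_fparam": True,
    "default_fparam": [0.25],    # list[float]; None when no default exists
}
print(model_constants["default_fparam"])
```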
deepmd/jax/infer/deep_eval.py (1)
445-447: LGTM! Consistent delegation pattern. The `has_default_fparam()` method correctly delegates to the underlying model, maintaining consistency with the established pattern across dpmodel components.

deepmd/jax/jax2tf/tfmodel.py (2)
72-80: LGTM! Proper backward compatibility handling. The use of `hasattr` checks ensures compatibility with models serialized before v3.1.2, gracefully falling back to `False` and `None` when the attributes are unavailable. The inline comment clearly documents this version dependency.
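The described fallback might be written roughly like this; `loaded` stands for the restored SavedModel object and the helper name is an assumption, not the actual tfmodel.py code:

```python
def load_default_fparam_attrs(loaded):
    """Probe optional attributes so exports from older versions still load."""
    if hasattr(loaded, "has_default_fparam") and hasattr(loaded, "get_default_fparam"):
        return bool(loaded.has_default_fparam()), loaded.get_default_fparam()
    # Older SavedModels: behave as if no default frame parameters exist.
    return False, None
```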
344-350: LGTM! Clean accessor implementation. The accessor methods correctly expose the internal state initialized during `__init__`, maintaining consistency with the pattern established across the codebase.

deepmd/jax/model/hlo.py (2)
61-86: LGTM! Backward-compatible parameter addition. The new parameters `has_default_fparam` and `default_fparam` have sensible defaults (`False` and `None`) that maintain backward compatibility during deserialization. The inline comment documents the version in which this feature was introduced.
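A minimal sketch of such a constructor, assuming keyword defaults are how the backward compatibility is achieved; the real HLO constructor has many more parameters:

```python
from typing import Optional


class HLOLikeModel:
    def __init__(
        self,
        *,
        has_default_fparam: bool = False,          # older payloads simply omit this
        default_fparam: Optional[list[float]] = None,
        **kwargs,
    ) -> None:
        # Stored as-is so the accessors can expose them later.
        self._has_default_fparam = has_default_fparam
        self._default_fparam = default_fparam
```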
336-342: LGTM! Straightforward accessor implementation. The accessor methods correctly expose the default frame parameter state, aligning with the public API established across the JAX and dpmodel layers.
deepmd/jax/jax2tf/serialization.py (1)
323-337: LGTM! Proper serialization of default fparam attributes. The implementation correctly:
- Wraps the accessors in `tf.function` for TensorFlow serialization (a sketch follows this list)
- Handles the `None` case by returning an empty tensor with the appropriate `tf.double` dtype
- Maintains consistency with other attribute serialization patterns in this file
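A sketch of that pattern, assuming the accessors are attached to the tf.Module being saved; the function and attribute names on `module` are illustrative, not the actual jax2tf/serialization.py code:

```python
import tensorflow as tf


def attach_default_fparam_accessors(module, default_fparam):
    """Expose default-fparam state as tf.functions on the module to be saved."""

    @tf.function(input_signature=[])
    def has_default_fparam() -> tf.Tensor:
        return tf.constant(default_fparam is not None)

    @tf.function(input_signature=[])
    def get_default_fparam() -> tf.Tensor:
        if default_fparam is None:
            # Empty tensor keeps a stable dtype for consumers of the SavedModel.
            return tf.constant([], dtype=tf.double)
        return tf.constant(default_fparam, dtype=tf.double)

    module.has_default_fparam = has_default_fparam
    module.get_default_fparam = get_default_fparam
```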
Co-authored-by: Copilot <[email protected]> Signed-off-by: Jinzhe Zeng <[email protected]>
Actionable comments posted: 1
Fix all issues with AI Agents 🤖
In @deepmd/jax/jax2tf/tfmodel.py:
- Around line 72-80: Update the incorrect version note: change the comment "No
attrs before v3.1.2" in the block that inspects self.model.has_default_fparam
and self.model.get_default_fparam to reflect the correct version "new in v3.1.1"
so it matches deepmd/jax/model/hlo.py; the relevant symbols to locate are the
hasattr checks for has_default_fparam and get_default_fparam and the assignments
to self._has_default_fparam and self.default_fparam.
♻️ Duplicate comments (3)
deepmd/dpmodel/fitting/general_fitting.py (1)
307-309: LGTM! Type annotation is correct. The return type annotation `list[float] | None` correctly matches the type of `self.default_fparam` (line 130). The implementation is straightforward and consistent with similar accessors across the codebase.

deepmd/dpmodel/atomic_model/dp_atomic_model.py (1)
243-245: LGTM! Delegation pattern is correct. The method correctly delegates to `self.fitting.get_default_fparam()` and the return type annotation `list[float] | None` matches the fitting's method signature.

deepmd/dpmodel/atomic_model/base_atomic_model.py (1)
98-100: LGTM! Base implementation is correct. The method correctly returns `None` to indicate no default frame parameters exist, which is consistent with `has_default_fparam()` returning `False` (lines 94-96). The return type annotation `list[float] | None` is correct.
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- deepmd/dpmodel/atomic_model/base_atomic_model.py
- deepmd/dpmodel/atomic_model/dp_atomic_model.py
- deepmd/dpmodel/fitting/general_fitting.py
- deepmd/dpmodel/model/make_model.py
- deepmd/jax/jax2tf/tfmodel.py
- deepmd/jax/model/hlo.py
🚧 Files skipped from review as they are similar to previous changes (2)
- deepmd/dpmodel/model/make_model.py
- deepmd/jax/model/hlo.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2024-10-15T22:22:24.889Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4219
File: deepmd/utils/learning_rate.py:48-53
Timestamp: 2024-10-15T22:22:24.889Z
Learning: Methods in `deepmd/utils/learning_rate.py` that return NumPy scalar types should have return type annotations using the corresponding NumPy types, such as `np.float64`.
Applied to files:
- deepmd/dpmodel/atomic_model/dp_atomic_model.py
- deepmd/jax/jax2tf/tfmodel.py
- deepmd/dpmodel/fitting/general_fitting.py
🧬 Code graph analysis (2)
deepmd/jax/jax2tf/tfmodel.py (6)
- deepmd/dpmodel/atomic_model/base_atomic_model.py (2): has_default_fparam (94-96), get_default_fparam (98-100)
- deepmd/dpmodel/atomic_model/dp_atomic_model.py (2): has_default_fparam (239-241), get_default_fparam (243-245)
- deepmd/dpmodel/fitting/general_fitting.py (2): has_default_fparam (303-305), get_default_fparam (307-309)
- deepmd/dpmodel/model/make_model.py (2): has_default_fparam (566-568), get_default_fparam (570-572)
- deepmd/jax/infer/deep_eval.py (1): has_default_fparam (445-447)
- deepmd/jax/jax2tf/serialization.py (2): has_default_fparam (324-325), get_default_fparam (330-335)

deepmd/dpmodel/fitting/general_fitting.py (7)
- deepmd/dpmodel/atomic_model/base_atomic_model.py (1): get_default_fparam (98-100)
- deepmd/dpmodel/atomic_model/dp_atomic_model.py (1): get_default_fparam (243-245)
- deepmd/dpmodel/model/make_model.py (1): get_default_fparam (570-572)
- deepmd/jax/jax2tf/tfmodel.py (1): get_default_fparam (348-350)
- deepmd/jax/model/hlo.py (1): get_default_fparam (340-342)
- deepmd/jax/jax2tf/serialization.py (1): get_default_fparam (330-335)
- deepmd/pt/model/task/fitting.py (1): get_default_fparam (644-645)
🔇 Additional comments (1)
deepmd/jax/jax2tf/tfmodel.py (1)
344-350: LGTM! Public API methods are correctly implemented. The `has_default_fparam()` and `get_default_fparam()` methods correctly expose the internal state initialized in lines 72-80. The return type annotation `list[float] | None` is consistent with other implementations across the codebase.
Codecov Report

❌ Patch coverage is

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##           master    #5123      +/-   ##
==========================================
- Coverage   82.15%   82.14%   -0.01%     
==========================================
  Files         709      709              
  Lines       72468    72509      +41     
  Branches     3616     3615       -1     
==========================================
+ Hits        59535    59565      +30     
- Misses      11769    11780      +11     
  Partials     1164     1164              
```

☔ View full report in Codecov by Sentry.
Summary by CodeRabbit