Skip to content

Conversation

@njzjz
Copy link
Member

@njzjz njzjz commented Jan 5, 2026

Summary by CodeRabbit

  • New Features
    • New APIs to check for and retrieve model default frame parameters.
    • Models can expose built-in default frame parameters when none are provided.
    • Inference now automatically uses exposed defaults when explicit parameters are missing.
    • Serialized TensorFlow exports and JAX/TF backends expose accessors for these defaults.

✏️ Tip: You can customize this high-level summary in your review settings.

Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR fixes JAX support for models with default frame parameters by adding the get_default_fparam() method throughout the JAX/dpmodel stack and updating serialization to save/load these values.

  • Adds get_default_fparam() method to return default frame parameters alongside existing has_default_fparam() check
  • Updates serialization to persist default_fparam values in HLO and SavedModel formats
  • Implements default fparam handling in JAX inference to satisfy XLA's static shape requirements

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 8 comments.

Show a summary per file
File Description
deepmd/jax/utils/serialization.py Adds has_default_fparam and default_fparam to serialized model constants
deepmd/jax/model/hlo.py Adds constructor parameters and getter methods for default fparam support
deepmd/jax/jax2tf/tfmodel.py Adds backward-compatible loading and getter methods for default fparam
deepmd/jax/jax2tf/serialization.py Adds TensorFlow functions to expose default fparam in SavedModel
deepmd/jax/infer/deep_eval.py Applies default fparam when no explicit fparam provided and adds has_default_fparam method
deepmd/dpmodel/model/make_model.py Delegates get_default_fparam to atomic model
deepmd/dpmodel/fitting/general_fitting.py Returns the default_fparam attribute
deepmd/dpmodel/atomic_model/dp_atomic_model.py Delegates get_default_fparam to fitting
deepmd/dpmodel/atomic_model/base_atomic_model.py Provides base implementation returning empty list

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Jan 5, 2026

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough

Walkthrough

Adds get_default_fparam() getters and has_default_fparam() checkers across model, fitting, JAX inference, TF wrapper/serialization layers; DeepEval now uses a default fparam when fparam is None by tiling it across frames for static-shape requirements.

Changes

Cohort / File(s) Summary
Atomic & fitting API
deepmd/dpmodel/atomic_model/base_atomic_model.py, deepmd/dpmodel/atomic_model/dp_atomic_model.py, deepmd/dpmodel/fitting/general_fitting.py, deepmd/dpmodel/model/make_model.py
Added `get_default_fparam() -> list[float]
JAX inference
deepmd/jax/infer/deep_eval.py
Added has_default_fparam() and logic to fetch get_default_fparam() when fparam is None, then tile it across frames to satisfy JAX static shapes.
TF model wrapper & HLO
deepmd/jax/jax2tf/tfmodel.py, deepmd/jax/model/hlo.py
TF wrapper/HLO store and expose _has_default_fparam / default_fparam; added has_default_fparam() and get_default_fparam() public accessors; HLO ctor accepts has_default_fparam and default_fparam.
TF serialization
deepmd/jax/jax2tf/serialization.py
Attach tf.function-wrapped has_default_fparam() and get_default_fparam() to saved tf_model; get_default_fparam() returns empty tensor when None.
Serialization utils
deepmd/jax/utils/serialization.py
Added has_default_fparam and default_fparam to model constants emitted during serialization.

Sequence Diagram(s)

sequenceDiagram
    participant Client
    participant DeepEval
    participant AtomicModel
    participant TFModel
    rect rgba(220,240,220,0.25)
        note left of DeepEval: Inference entrypoint (new default-fparam flow)
    end

    Client->>DeepEval: infer(coords, fparam=None)
    alt fparam is None and DeepEval.has_default_fparam()
        DeepEval->>AtomicModel: has_default_fparam()
        AtomicModel-->>DeepEval: True
        DeepEval->>AtomicModel: get_default_fparam()
        AtomicModel-->>DeepEval: default_fparam (list)
        rect rgba(200,220,255,0.18)
            note over DeepEval: Tile default_fparam across frames\n(to meet static-shape requirement)
        end
        DeepEval->>TFModel: call(model_inputs with tiled fparam)
    else fparam provided or no default
        DeepEval->>TFModel: call(model_inputs with given/None fparam)
    end
    TFModel-->>Client: inference result
Loading

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

Suggested reviewers

  • wanghan-iapcm

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 77.42% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately reflects the main change: adding support for default fparam in the JAX backend by introducing getter methods across multiple model classes.
✨ Finishing touches
  • 📝 Generate docstrings

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 5

Fix all issues with AI Agents 🤖
In @deepmd/dpmodel/atomic_model/base_atomic_model.py:
- Around line 98-100: get_default_fparam currently declares the wrong return
type and returns an empty list even when has_default_fparam() is False; change
its signature to return list[float] | None and update the implementation to
return None when no default frame parameters exist (consistent with
has_default_fparam()), otherwise return a list[float] of defaults; update
references if any to expect list[float] | None.

In @deepmd/dpmodel/atomic_model/dp_atomic_model.py:
- Around line 243-245: The return type annotation on get_default_fparam is
incorrect; update the signature of get_default_fparam (method name:
get_default_fparam) to return list[float] | None instead of list[int] | None to
match the actual return value from self.fitting.get_default_fparam(), which is
defined as list[float] | None in general_fitting.py and used with tf.double
elsewhere.

In @deepmd/dpmodel/fitting/general_fitting.py:
- Around line 307-309: The return type annotation for get_default_fparam is
incorrect; change it from list[int] | None to list[float] | None to match the
attribute self.default_fparam (which is defined as list[float] | None) and
reflect that frame parameters are floats; update the function signature of
get_default_fparam accordingly so the annotation and returned value are
consistent with the default_fparam attribute.

In @deepmd/dpmodel/model/make_model.py:
- Around line 570-572: Update the return type annotation of get_default_fparam
to reflect floating-point frame parameters: change its signature from list[int]
| None to list[float] | None and ensure it still returns the result of
self.atomic_model.get_default_fparam() (which also returns list[float] | None).
📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b98f6c5 and 1acfa9d.

📒 Files selected for processing (9)
  • deepmd/dpmodel/atomic_model/base_atomic_model.py
  • deepmd/dpmodel/atomic_model/dp_atomic_model.py
  • deepmd/dpmodel/fitting/general_fitting.py
  • deepmd/dpmodel/model/make_model.py
  • deepmd/jax/infer/deep_eval.py
  • deepmd/jax/jax2tf/serialization.py
  • deepmd/jax/jax2tf/tfmodel.py
  • deepmd/jax/model/hlo.py
  • deepmd/jax/utils/serialization.py
🧰 Additional context used
🧬 Code graph analysis (6)
deepmd/jax/infer/deep_eval.py (4)
deepmd/dpmodel/atomic_model/base_atomic_model.py (2)
  • has_default_fparam (94-96)
  • get_default_fparam (98-100)
deepmd/dpmodel/atomic_model/dp_atomic_model.py (2)
  • has_default_fparam (239-241)
  • get_default_fparam (243-245)
deepmd/dpmodel/model/make_model.py (2)
  • has_default_fparam (566-568)
  • get_default_fparam (570-572)
deepmd/dpmodel/infer/deep_eval.py (1)
  • has_default_fparam (128-130)
deepmd/dpmodel/model/make_model.py (8)
deepmd/dpmodel/atomic_model/base_atomic_model.py (1)
  • get_default_fparam (98-100)
deepmd/dpmodel/atomic_model/dp_atomic_model.py (1)
  • get_default_fparam (243-245)
deepmd/dpmodel/fitting/general_fitting.py (1)
  • get_default_fparam (307-309)
deepmd/jax/jax2tf/serialization.py (1)
  • get_default_fparam (330-335)
deepmd/jax/jax2tf/tfmodel.py (1)
  • get_default_fparam (348-350)
deepmd/jax/model/hlo.py (1)
  • get_default_fparam (340-342)
deepmd/pt/model/model/make_model.py (1)
  • get_default_fparam (536-537)
deepmd/pt/model/task/fitting.py (1)
  • get_default_fparam (644-645)
deepmd/dpmodel/atomic_model/base_atomic_model.py (7)
deepmd/dpmodel/atomic_model/dp_atomic_model.py (1)
  • get_default_fparam (243-245)
deepmd/dpmodel/fitting/general_fitting.py (1)
  • get_default_fparam (307-309)
deepmd/dpmodel/model/make_model.py (1)
  • get_default_fparam (570-572)
deepmd/jax/jax2tf/serialization.py (1)
  • get_default_fparam (330-335)
deepmd/jax/jax2tf/tfmodel.py (1)
  • get_default_fparam (348-350)
deepmd/jax/model/hlo.py (1)
  • get_default_fparam (340-342)
deepmd/pt/model/atomic_model/dp_atomic_model.py (1)
  • get_default_fparam (384-385)
deepmd/jax/utils/serialization.py (9)
deepmd/dpmodel/atomic_model/base_atomic_model.py (2)
  • has_default_fparam (94-96)
  • get_default_fparam (98-100)
deepmd/dpmodel/atomic_model/dp_atomic_model.py (2)
  • has_default_fparam (239-241)
  • get_default_fparam (243-245)
deepmd/dpmodel/fitting/general_fitting.py (2)
  • has_default_fparam (303-305)
  • get_default_fparam (307-309)
deepmd/dpmodel/model/make_model.py (2)
  • has_default_fparam (566-568)
  • get_default_fparam (570-572)
deepmd/jax/infer/deep_eval.py (1)
  • has_default_fparam (445-447)
deepmd/jax/jax2tf/serialization.py (2)
  • has_default_fparam (324-325)
  • get_default_fparam (330-335)
deepmd/jax/jax2tf/tfmodel.py (2)
  • has_default_fparam (344-346)
  • get_default_fparam (348-350)
deepmd/jax/model/hlo.py (2)
  • has_default_fparam (336-338)
  • get_default_fparam (340-342)
deepmd/dpmodel/infer/deep_eval.py (1)
  • has_default_fparam (128-130)
deepmd/dpmodel/atomic_model/dp_atomic_model.py (9)
deepmd/dpmodel/atomic_model/base_atomic_model.py (1)
  • get_default_fparam (98-100)
deepmd/dpmodel/fitting/general_fitting.py (1)
  • get_default_fparam (307-309)
deepmd/dpmodel/model/make_model.py (1)
  • get_default_fparam (570-572)
deepmd/jax/jax2tf/serialization.py (1)
  • get_default_fparam (330-335)
deepmd/jax/jax2tf/tfmodel.py (1)
  • get_default_fparam (348-350)
deepmd/jax/model/hlo.py (1)
  • get_default_fparam (340-342)
deepmd/pt/model/atomic_model/dp_atomic_model.py (1)
  • get_default_fparam (384-385)
deepmd/pt/model/model/make_model.py (1)
  • get_default_fparam (536-537)
deepmd/pt/model/task/fitting.py (1)
  • get_default_fparam (644-645)
deepmd/jax/model/hlo.py (7)
deepmd/dpmodel/atomic_model/base_atomic_model.py (2)
  • has_default_fparam (94-96)
  • get_default_fparam (98-100)
deepmd/dpmodel/atomic_model/dp_atomic_model.py (2)
  • has_default_fparam (239-241)
  • get_default_fparam (243-245)
deepmd/dpmodel/fitting/general_fitting.py (2)
  • has_default_fparam (303-305)
  • get_default_fparam (307-309)
deepmd/dpmodel/model/make_model.py (2)
  • has_default_fparam (566-568)
  • get_default_fparam (570-572)
deepmd/jax/infer/deep_eval.py (1)
  • has_default_fparam (445-447)
deepmd/jax/jax2tf/serialization.py (2)
  • has_default_fparam (324-325)
  • get_default_fparam (330-335)
deepmd/jax/jax2tf/tfmodel.py (2)
  • has_default_fparam (344-346)
  • get_default_fparam (348-350)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (30)
  • GitHub Check: CodeQL analysis (python)
  • GitHub Check: Agent
  • GitHub Check: Analyze (python)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Test Python (5, 3.13)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Test Python (2, 3.13)
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Test C++ (false, false, false, true)
  • GitHub Check: Test C++ (true, false, false, true)
  • GitHub Check: Test Python (1, 3.10)
  • GitHub Check: Test C++ (true, true, true, false)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Test C++ (false, true, true, false)
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Test Python (5, 3.10)
  • GitHub Check: Test Python (1, 3.13)
  • GitHub Check: Test Python (2, 3.10)
  • GitHub Check: Test Python (4, 3.13)
  • GitHub Check: Test Python (4, 3.10)
  • GitHub Check: Test Python (3, 3.10)
  • GitHub Check: Test Python (6, 3.10)
  • GitHub Check: Test Python (6, 3.13)
  • GitHub Check: Test Python (3, 3.13)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
🔇 Additional comments (7)
deepmd/jax/utils/serialization.py (1)

136-137: LGTM! Serialization of default fparam metadata.

The additions correctly serialize the default frame parameter metadata alongside other model constants, enabling consumers of the serialized model to query this information.

deepmd/jax/infer/deep_eval.py (1)

445-447: LGTM! Consistent delegation pattern.

The has_default_fparam() method correctly delegates to the underlying model, maintaining consistency with the established pattern across dpmodel components.

deepmd/jax/jax2tf/tfmodel.py (2)

72-80: LGTM! Proper backward compatibility handling.

The use of hasattr checks ensures compatibility with models serialized before v3.1.2, gracefully falling back to False and None when the attributes are unavailable. The inline comment clearly documents this version dependency.


344-350: LGTM! Clean accessor implementation.

The accessor methods correctly expose the internal state initialized during __init__, maintaining consistency with the pattern established across the codebase.

deepmd/jax/model/hlo.py (2)

61-86: LGTM! Backward-compatible parameter addition.

The new parameters has_default_fparam and default_fparam have sensible defaults (False and None) that maintain backward compatibility during deserialization. The inline comment documents the version this feature was introduced.


336-342: LGTM! Straightforward accessor implementation.

The accessor methods correctly expose the default frame parameter state, aligning with the public API established across the JAX and dpmodel layers.

deepmd/jax/jax2tf/serialization.py (1)

323-337: LGTM! Proper serialization of default fparam attributes.

The implementation correctly:

  • Wraps the accessors in tf.function for TensorFlow serialization
  • Handles the None case by returning an empty tensor with the appropriate tf.double dtype
  • Maintains consistency with other attribute serialization patterns in this file

Co-authored-by: Copilot <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

Fix all issues with AI Agents 🤖
In @deepmd/jax/jax2tf/tfmodel.py:
- Around line 72-80: Update the incorrect version note: change the comment "No
attrs before v3.1.2" in the block that inspects self.model.has_default_fparam
and self.model.get_default_fparam to reflect the correct version "new in v3.1.1"
so it matches deepmd/jax/model/hlo.py; the relevant symbols to locate are the
hasattr checks for has_default_fparam and get_default_fparam and the assignments
to self._has_default_fparam and self.default_fparam.
♻️ Duplicate comments (3)
deepmd/dpmodel/fitting/general_fitting.py (1)

307-309: LGTM! Type annotation is correct.

The return type annotation list[float] | None correctly matches the type of self.default_fparam (line 130). The implementation is straightforward and consistent with similar accessors across the codebase.

deepmd/dpmodel/atomic_model/dp_atomic_model.py (1)

243-245: LGTM! Delegation pattern is correct.

The method correctly delegates to self.fitting.get_default_fparam() and the return type annotation list[float] | None matches the fitting's method signature.

deepmd/dpmodel/atomic_model/base_atomic_model.py (1)

98-100: LGTM! Base implementation is correct.

The method correctly returns None to indicate no default frame parameters exist, which is consistent with has_default_fparam() returning False (lines 94-96). The return type annotation list[float] | None is correct.

📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1acfa9d and 9cb5af0.

📒 Files selected for processing (6)
  • deepmd/dpmodel/atomic_model/base_atomic_model.py
  • deepmd/dpmodel/atomic_model/dp_atomic_model.py
  • deepmd/dpmodel/fitting/general_fitting.py
  • deepmd/dpmodel/model/make_model.py
  • deepmd/jax/jax2tf/tfmodel.py
  • deepmd/jax/model/hlo.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • deepmd/dpmodel/model/make_model.py
  • deepmd/jax/model/hlo.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2024-10-15T22:22:24.889Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4219
File: deepmd/utils/learning_rate.py:48-53
Timestamp: 2024-10-15T22:22:24.889Z
Learning: Methods in `deepmd/utils/learning_rate.py` that return NumPy scalar types should have return type annotations using the corresponding NumPy types, such as `np.float64`.

Applied to files:

  • deepmd/dpmodel/atomic_model/dp_atomic_model.py
  • deepmd/jax/jax2tf/tfmodel.py
  • deepmd/dpmodel/fitting/general_fitting.py
🧬 Code graph analysis (2)
deepmd/jax/jax2tf/tfmodel.py (6)
deepmd/dpmodel/atomic_model/base_atomic_model.py (2)
  • has_default_fparam (94-96)
  • get_default_fparam (98-100)
deepmd/dpmodel/atomic_model/dp_atomic_model.py (2)
  • has_default_fparam (239-241)
  • get_default_fparam (243-245)
deepmd/dpmodel/fitting/general_fitting.py (2)
  • has_default_fparam (303-305)
  • get_default_fparam (307-309)
deepmd/dpmodel/model/make_model.py (2)
  • has_default_fparam (566-568)
  • get_default_fparam (570-572)
deepmd/jax/infer/deep_eval.py (1)
  • has_default_fparam (445-447)
deepmd/jax/jax2tf/serialization.py (2)
  • has_default_fparam (324-325)
  • get_default_fparam (330-335)
deepmd/dpmodel/fitting/general_fitting.py (7)
deepmd/dpmodel/atomic_model/base_atomic_model.py (1)
  • get_default_fparam (98-100)
deepmd/dpmodel/atomic_model/dp_atomic_model.py (1)
  • get_default_fparam (243-245)
deepmd/dpmodel/model/make_model.py (1)
  • get_default_fparam (570-572)
deepmd/jax/jax2tf/tfmodel.py (1)
  • get_default_fparam (348-350)
deepmd/jax/model/hlo.py (1)
  • get_default_fparam (340-342)
deepmd/jax/jax2tf/serialization.py (1)
  • get_default_fparam (330-335)
deepmd/pt/model/task/fitting.py (1)
  • get_default_fparam (644-645)
🔇 Additional comments (1)
deepmd/jax/jax2tf/tfmodel.py (1)

344-350: LGTM! Public API methods are correctly implemented.

The has_default_fparam() and get_default_fparam() methods correctly expose the internal state initialized in lines 72-80. The return type annotation list[float] | None is consistent with other implementations across the codebase.

@codecov
Copy link

codecov bot commented Jan 5, 2026

Codecov Report

❌ Patch coverage is 65.85366% with 14 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.14%. Comparing base (b98f6c5) to head (9cb5af0).
⚠️ Report is 4 commits behind head on master.

Files with missing lines Patch % Lines
deepmd/jax/jax2tf/serialization.py 54.54% 5 Missing ⚠️
deepmd/jax/infer/deep_eval.py 33.33% 4 Missing ⚠️
deepmd/jax/jax2tf/tfmodel.py 70.00% 3 Missing ⚠️
deepmd/dpmodel/atomic_model/base_atomic_model.py 50.00% 1 Missing ⚠️
deepmd/jax/model/hlo.py 83.33% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5123      +/-   ##
==========================================
- Coverage   82.15%   82.14%   -0.01%     
==========================================
  Files         709      709              
  Lines       72468    72509      +41     
  Branches     3616     3615       -1     
==========================================
+ Hits        59535    59565      +30     
- Misses      11769    11780      +11     
  Partials     1164     1164              

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

@njzjz njzjz changed the title fix(jax): fix support with default fparam fix(jax): fix support for default fparam Jan 5, 2026
@wanghan-iapcm wanghan-iapcm added this pull request to the merge queue Jan 6, 2026
@njzjz njzjz added this pull request to the merge queue Jan 6, 2026
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Jan 6, 2026
@njzjz njzjz added this pull request to the merge queue Jan 7, 2026
Merged via the queue into deepmodeling:master with commit be3b876 Jan 7, 2026
58 checks passed
@njzjz njzjz deleted the jax-has-default-fparam branch January 7, 2026 18:28
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants