
Conversation

@OutisLi (Collaborator) commented Jan 1, 2026

Summary by CodeRabbit

  • Bug Fixes
    • Improved numerical stability of pairwise distance computations to prevent NaN gradients for zero differences and padded/masked entries; outputs remain unchanged while gradient behavior is now robust for those edge cases.


Copilot AI review requested due to automatic review settings January 1, 2026 12:34
@github-actions github-actions bot added the Python label Jan 1, 2026
@dosubot dosubot bot added the bug label Jan 1, 2026
Copilot AI (Contributor) left a comment

Pull request overview

This PR fixes a potential NaN gradient issue in the PyTorch implementation of the pairtab atomic model by replacing the standard torch.linalg.norm computation with a safe norm that uses epsilon clamping.

Key Changes:

  • Modified the _get_pairwise_dist method to use torch.sqrt(torch.sum(diff * diff, dim=-1, keepdim=True).clamp(min=1e-14)) instead of torch.linalg.norm (see the sketch after this list)
  • Added comprehensive documentation in the Notes section explaining when and why this safe norm is needed
  • Added inline comments explaining the epsilon value choice and its purpose
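
The failure mode is easy to reproduce: the gradient of torch.linalg.norm is x / ||x||, which evaluates to 0/0 = NaN at the origin. A minimal, self-contained PyTorch sketch (not the PR's actual code) illustrating the NaN and the clamp-based fix described above:

import torch

# Padded neighbor entries produce a zero difference vector.
diff = torch.zeros(3, requires_grad=True)

# Plain norm: the value is 0, but d||x||/dx = x/||x|| is 0/0 = NaN at the origin.
torch.linalg.norm(diff).backward()
print(diff.grad)  # tensor([nan, nan, nan])

# Clamp-based safe norm from this PR: the clamp blocks the gradient path
# below the epsilon, so the backward pass stays finite.
diff2 = torch.zeros(3, requires_grad=True)
rr = torch.sqrt(torch.sum(diff2 * diff2, dim=-1, keepdim=True).clamp(min=1e-14))
rr.sum().backward()
print(diff2.grad)  # tensor([0., 0., 0.])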


coderabbitai bot (Contributor) commented Jan 1, 2026

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough

Replaced torch.linalg.norm in _get_pairwise_dist with manual computation of sqrt(sum(diff * diff)) and a guarded sqrt that masks zero differences so padded entries produce zero outputs and zero gradients, preventing NaNs for zero-diff vectors.

Changes

Cohort / File(s): Numerical stability fix (deepmd/pt/model/atomic_model/pairtab_atomic_model.py)
Summary: Replaced torch.linalg.norm(diff, dim=-1, keepdim=True) with an explicit diff_sq = torch.sum(diff * diff, dim=-1, keepdim=True) and a guarded sqrt using a mask so zero diffs yield zero values and zero gradients (handles padding); added explanatory comments.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Suggested reviewers

  • njzjz

Pre-merge checks and finishing touches

❌ Failed checks (1 inconclusive)
  • Title check: ❓ Inconclusive. The title 'fix(pt): pairtab' is vague and lacks specificity about what bug is being fixed or what the actual change accomplishes. Resolution: provide a more descriptive title that explains the specific issue being fixed, such as 'fix(pt): handle zero gradients in pairtab distance computation' or 'fix(pt): prevent NaN gradients in pairtab padding entries'.
✅ Passed checks (2 passed)
  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 100.00%, which meets the required threshold of 80.00%.
✨ Finishing touches
  • 📝 Generate docstrings

📜 Recent review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5a21023 and b646fa1.

📒 Files selected for processing (1)
  • deepmd/pt/model/atomic_model/pairtab_atomic_model.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (28)
  • GitHub Check: Test Python (2, 3.10)
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Test Python (1, 3.13)
  • GitHub Check: Test Python (6, 3.13)
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Test Python (6, 3.10)
  • GitHub Check: Test Python (3, 3.10)
  • GitHub Check: Test Python (4, 3.10)
  • GitHub Check: Test C++ (true, false, false, true)
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Test Python (5, 3.13)
  • GitHub Check: Test Python (2, 3.13)
  • GitHub Check: Test C++ (false, true, true, false)
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Test C++ (true, true, true, false)
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Test Python (1, 3.10)
  • GitHub Check: Test Python (3, 3.13)
  • GitHub Check: Test Python (4, 3.13)
  • GitHub Check: Test Python (5, 3.10)
  • GitHub Check: Test C++ (false, false, false, true)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Analyze (python)
  • GitHub Check: Analyze (c-cpp)
🔇 Additional comments (2)
deepmd/pt/model/atomic_model/pairtab_atomic_model.py (2)

395-400: LGTM: Documentation clearly explains the gradient-safe behavior.

The notes accurately describe the zero-distance handling and gradient properties.


407-418: Simplify the masking logic to eliminate redundant squeeze/unsqueeze operations.

The current implementation correctly handles zero-distance cases using the safe-sqrt pattern. However, the squeeze/unsqueeze operations on lines 410–417 are unnecessary. Since diff_sq is created with keepdim=True, you can create the mask directly without squeezing and use it throughout, eliminating the redundant shape manipulations.

Proposed simplification
 diff = coord_r - coord_l
 diff_sq = torch.sum(diff * diff, dim=-1, keepdim=True)

 # When diff is zero, output is zero and gradient is also zero
-mask = diff_sq.squeeze(-1) > 0
+mask = diff_sq > 0
 pairwise_rr = torch.where(
-    mask.unsqueeze(-1),
+    mask,
     torch.sqrt(
-        torch.where(mask.unsqueeze(-1), diff_sq, torch.ones_like(diff_sq))
+        torch.where(mask, diff_sq, torch.ones_like(diff_sq))
     ),
     torch.zeros_like(diff_sq),
 ).squeeze(-1)

This approach aligns with the safe_for_sqrt pattern used in the dpmodel backend and improves code readability without changing the gradient behavior for zero-distance cases.
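
For reference, a self-contained check of the simplified pattern (hypothetical tensor shapes; coord_l and coord_r stand in for the model's coordinate slices) confirming zero outputs and NaN-free gradients when coordinates coincide:

import torch

# Identical coordinates mimic padded entries (zero difference vectors).
coord_l = torch.zeros(1, 2, 1, 3, requires_grad=True)
coord_r = torch.zeros(1, 2, 1, 3)

diff = coord_r - coord_l
diff_sq = torch.sum(diff * diff, dim=-1, keepdim=True)
mask = diff_sq > 0
# Double-where trick: sqrt never sees a zero input, and the masked branch
# carries no gradient, so the backward pass stays NaN-free.
pairwise_rr = torch.where(
    mask,
    torch.sqrt(torch.where(mask, diff_sq, torch.ones_like(diff_sq))),
    torch.zeros_like(diff_sq),
).squeeze(-1)

pairwise_rr.sum().backward()
print(pairwise_rr)                      # all zeros
print(torch.isnan(coord_l.grad).any())  # tensor(False)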



coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
deepmd/pt/model/atomic_model/pairtab_atomic_model.py (1)

409-416: LGTM! Safe norm implementation correctly prevents NaN gradients.

The implementation correctly computes the Euclidean norm with numerical stability:

  • The epsilon value (1e-14) is well-chosen: small enough to not affect physical distances (atomic distances typically > 0.1 Å, squared > 0.01) yet large enough to prevent gradient issues
  • The clamp on the squared sum (before sqrt) is the right approach to prevent unbounded gradients
  • Inline comments clearly explain the rationale
Optional: Consider defining epsilon as a named constant

For improved maintainability, you could define the epsilon as a class-level constant:

class PairTabAtomicModel(BaseAtomicModel):
    # Epsilon for safe norm computation to prevent NaN gradients
    _SAFE_NORM_EPSILON = 1e-14
    ...

Then use it in the computation:

        pairwise_rr = torch.sqrt(
-            torch.sum(diff * diff, dim=-1, keepdim=True).clamp(min=1e-14)
+            torch.sum(diff * diff, dim=-1, keepdim=True).clamp(min=self._SAFE_NORM_EPSILON)
        ).squeeze(-1)

This makes it easier to adjust the epsilon value consistently if needed in the future. However, this is a minor improvement and can be deferred.

📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b98f6c5 and e6e35b8.

📒 Files selected for processing (1)
  • deepmd/pt/model/atomic_model/pairtab_atomic_model.py
🧰 Additional context used
🧠 Learnings (1)
📓 Common learnings
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4144
File: source/api_cc/tests/test_deeppot_dpa_pt.cc:166-246
Learning: Refactoring between test classes `TestInferDeepPotDpaPt` and `TestInferDeepPotDpaPtNopbc` is addressed in PR #3905.
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (30)
  • GitHub Check: Agent
  • GitHub Check: CodeQL analysis (python)
  • GitHub Check: Test C++ (true, true, true, false)
  • GitHub Check: Test C++ (true, false, false, true)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Test C++ (false, false, false, true)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Test C++ (false, true, true, false)
  • GitHub Check: Analyze (python)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Test Python (5, 3.13)
  • GitHub Check: Test Python (4, 3.13)
  • GitHub Check: Test Python (6, 3.10)
  • GitHub Check: Test Python (4, 3.10)
  • GitHub Check: Test Python (5, 3.10)
  • GitHub Check: Test Python (6, 3.13)
  • GitHub Check: Test Python (3, 3.13)
  • GitHub Check: Test Python (2, 3.10)
  • GitHub Check: Test Python (2, 3.13)
  • GitHub Check: Test Python (1, 3.13)
  • GitHub Check: Test Python (3, 3.10)
  • GitHub Check: Test Python (1, 3.10)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
🔇 Additional comments (1)
deepmd/pt/model/atomic_model/pairtab_atomic_model.py (1)

395-402: LGTM! Clear documentation of the numerical stability fix.

The Notes section clearly explains the rationale for the safe norm computation and when zero difference vectors can occur. This will help future maintainers understand why the epsilon clamp is necessary.

codecov bot commented Jan 1, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 82.19%. Comparing base (b98f6c5) to head (b646fa1).
⚠️ Report is 7 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5119      +/-   ##
==========================================
+ Coverage   82.15%   82.19%   +0.04%     
==========================================
  Files         709      712       +3     
  Lines       72468    74510    +2042     
  Branches     3616     3616              
==========================================
+ Hits        59535    61245    +1710     
- Misses      11769    12102     +333     
+ Partials     1164     1163       -1     

☔ View full report in Codecov by Sentry.

@OutisLi OutisLi requested review from iProzd and njzjz January 5, 2026 03:33
njzjz (Member) left a comment

Your current implementation changes the result. You may follow this implementation:

from typing import Any

import array_api_compat


def safe_for_vector_norm(
    x: Any, /, *, axis: Any | None = None, keepdims: bool = False, ord: Any = 2
) -> Any:
    """Safe version of the vector norm that has a gradient of 0 at x = 0."""
    xp = array_api_compat.array_namespace(x)
    mask = xp.sum(xp.square(x), axis=axis, keepdims=True) > 0
    if keepdims:
        mask_squeezed = mask
    else:
        mask_squeezed = xp.squeeze(mask, axis=axis)
    return xp.where(
        mask_squeezed,
        xp.linalg.vector_norm(
            xp.where(mask, x, xp.ones_like(x)), axis=axis, keepdims=keepdims, ord=ord
        ),
        xp.zeros_like(mask_squeezed, dtype=x.dtype),
    )
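
A hypothetical usage sketch (NumPy here; any array-API-compatible backend such as PyTorch behaves the same way), where the second row mimics a padded zero-diff entry:

import numpy as np

diff = np.array([[3.0, 4.0, 0.0], [0.0, 0.0, 0.0]])
print(safe_for_vector_norm(diff, axis=-1))  # [5. 0.]: zero diff maps to a zero norm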

@OutisLi OutisLi requested a review from njzjz January 5, 2026 11:00
@njzjz njzjz enabled auto-merge January 7, 2026 19:09
@njzjz njzjz closed this Jan 8, 2026
auto-merge was automatically disabled January 8, 2026 02:45

Pull request was closed

@njzjz njzjz reopened this Jan 8, 2026
@njzjz njzjz enabled auto-merge January 8, 2026 02:45
@njzjz njzjz added this pull request to the merge queue Jan 8, 2026
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Jan 8, 2026
@njzjz njzjz added this pull request to the merge queue Jan 8, 2026
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Jan 8, 2026
@njzjz njzjz added this pull request to the merge queue Jan 8, 2026
Merged via the queue into deepmodeling:master with commit 5f73113 Jan 8, 2026
124 of 128 checks passed

3 participants