Skip to content

fix: update JAX DLPack API to remove deprecation warning#71

Merged
ealt merged 1 commit intomainfrom
claude/issue-70-20250908-1901
Sep 8, 2025
Merged

fix: update JAX DLPack API to remove deprecation warning#71
ealt merged 1 commit intomainfrom
claude/issue-70-20250908-1901

Conversation

@adamimos
Copy link
Collaborator

@adamimos adamimos commented Sep 8, 2025

The jax.dlpack.to_dlpack function was deprecated in JAX v0.6.0 and removed in v0.7.0.

Updated jax_to_torch function to pass JAX arrays directly to torch.utils.dlpack.from_dlpack, which uses the newer dlpack protocol automatically.

Fixes #70

Generated with Claude Code

The jax.dlpack.to_dlpack function was deprecated in JAX v0.6.0 and removed in v0.7.0.
Updated jax_to_torch function to pass JAX arrays directly to torch.utils.dlpack.from_dlpack,
which uses the newer __dlpack__ protocol automatically.

Fixes #70

Co-authored-by: adamimos <adamimos@users.noreply.github.com>
@claude
Copy link

claude bot commented Sep 8, 2025

Code Review

Overall Assessment: LGTM with minor observations

This is a solid fix that addresses the JAX DLPack API deprecation warning properly.

Code Quality & Best Practices

  • Excellent: The solution follows the newer DLPack protocol correctly by passing JAX arrays directly to torch.utils.dlpack.from_dlpack
  • Good: Maintains the existing fallback mechanism to numpy conversion
  • Good: Preserves all error handling and warning patterns
  • Good: Correctly updated the docstring to remove the obsolete device parameter

Technical Correctness

  • API Migration: Correctly removes the deprecated jax.dlpack.to_dlpack() call
  • Protocol Compliance: Uses the newer dlpack protocol automatically through direct array passing
  • Backward Compatibility: Maintains the same function signature and behavior
  • Error Handling: Preserves existing try/catch with numpy fallback

Performance Considerations

  • Improvement: Eliminates an intermediate DLPack tensor creation, potentially slightly more efficient
  • Maintained: Zero-copy GPU transfers still work as expected
  • No Regression: Fallback path unchanged for cases where DLPack fails

Security & Safety

  • No Security Concerns: Change is purely API migration
  • Type Safety: Maintains existing type annotations
  • Exception Handling: Preserved comprehensive error handling

Test Coverage

  • Good: Existing tests in tests/utils/test_pytorch_utils.py should cover this change
  • Recommendation: Consider running the test suite to verify GPU/CPU conversion still works correctly

Minor Observations

  1. Unused Import: The 'from jax import dlpack as jax_dlpack' import on line 13 is now unused and could be removed
  2. Documentation: The module docstring mentions DLPack but could be updated to reflect the newer protocol usage

Code Style Compliance

  • Line Length: Within 120 character limit
  • Type Hints: Properly maintained
  • Functional Style: Maintains pure function approach
  • Naming: Follows project conventions

Recommendation

This PR should be merged after addressing the unused import. The fix correctly resolves the deprecation warning while maintaining full functionality.

@adamimos
Copy link
Collaborator Author

adamimos commented Sep 8, 2025

Code Review

Overall Assessment: LGTM with minor observations

This is a solid fix that addresses the JAX DLPack API deprecation warning properly.

Code Quality & Best Practices

  • Excellent: The solution follows the newer DLPack protocol correctly by passing JAX arrays directly to torch.utils.dlpack.from_dlpack
  • Good: Maintains the existing fallback mechanism to numpy conversion
  • Good: Preserves all error handling and warning patterns
  • Good: Correctly updated the docstring to remove the obsolete device parameter

Technical Correctness

  • API Migration: Correctly removes the deprecated jax.dlpack.to_dlpack() call
  • Protocol Compliance: Uses the newer dlpack protocol automatically through direct array passing
  • Backward Compatibility: Maintains the same function signature and behavior
  • Error Handling: Preserves existing try/catch with numpy fallback

Performance Considerations

  • Improvement: Eliminates an intermediate DLPack tensor creation, potentially slightly more efficient
  • Maintained: Zero-copy GPU transfers still work as expected
  • No Regression: Fallback path unchanged for cases where DLPack fails

Security & Safety

  • No Security Concerns: Change is purely API migration
  • Type Safety: Maintains existing type annotations
  • Exception Handling: Preserved comprehensive error handling

Test Coverage

  • Good: Existing tests in tests/utils/test_pytorch_utils.py should cover this change
  • Recommendation: Consider running the test suite to verify GPU/CPU conversion still works correctly

Minor Observations

  1. Unused Import: The 'from jax import dlpack as jax_dlpack' import on line 13 is now unused and could be removed
  2. Documentation: The module docstring mentions DLPack but could be updated to reflect the newer protocol usage

Code Style Compliance

  • Line Length: Within 120 character limit
  • Type Hints: Properly maintained
  • Functional Style: Maintains pure function approach
  • Naming: Follows project conventions

Recommendation

This PR should be merged after addressing the unused import. The fix correctly resolves the deprecation warning while maintaining full functionality.

@claude please make changes to address the unused import

@claude
Copy link

claude bot commented Sep 8, 2025

Claude finished @adamimos's task —— View job


Addressing Unused Import

Todo List:

  • Read the current pytorch_utils.py file to confirm the unused import
  • Verify the import usage in the codebase
  • Update comment with findings

Findings:

After reviewing the code more carefully, I need to correct my previous assessment. The from jax import dlpack as jax_dlpack import on line 13 is NOT actually unused.

While the PR correctly removed the usage of jax_dlpack.to_dlpack() from the jax_to_torch function (which was deprecated), the jax_dlpack import is still needed for the torch_to_jax function on line 69:

jax_array = jax_dlpack.from_dlpack(dlpack_tensor)

The jax_dlpack.from_dlpack() function is still the correct way to convert from DLPack tensors to JAX arrays, so this import should remain.

Conclusion:

No changes are needed. My previous code review incorrectly identified this as an unused import. The import is actively used and required for the torch_to_jax function to work properly.

Copy link
Collaborator

@ealt ealt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good job, Claude!

@ealt ealt merged commit 26067b5 into main Sep 8, 2025
4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

API change for DLPack conversion

2 participants