Skip to content

Commit 86276f8

Browse files
committed
fix: update JAX DLPack API to remove deprecation warning (#71)
1 parent c21ebd4 commit 86276f8

File tree

1 file changed

+1
-3
lines changed

1 file changed

+1
-3
lines changed

simplexity/utils/pytorch_utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ def jax_to_torch(jax_array: jax.Array) -> torch.Tensor:
2727
2828
Args:
2929
jax_array: JAX array to convert
30-
device: Target PyTorch device (optional, will use JAX array's device if None)
3130
3231
Returns:
3332
PyTorch tensor
@@ -36,8 +35,7 @@ def jax_to_torch(jax_array: jax.Array) -> torch.Tensor:
3635
ImportError: If JAX or PyTorch is not available
3736
"""
3837
try:
39-
dlpack_tensor = jax_dlpack.to_dlpack(jax_array)
40-
torch_tensor = torch_dlpack.from_dlpack(dlpack_tensor)
38+
torch_tensor = torch_dlpack.from_dlpack(jax_array)
4139
return torch_tensor
4240

4341
except Exception as e:

0 commit comments

Comments
 (0)