Skip to content

Commit a1b5d0b

Browse files
committed
Maintain compatibility with older JAX versions
1 parent 4d01933 commit a1b5d0b

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

s2fft/utils/torch_wrapper.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,12 @@ def jax_array_to_torch_tensor(jax_array: jax.Array) -> torch.Tensor:
7676
Torch tensor object with equivalent data to `jax_array`.
7777
7878
"""
79-
return torch.utils.dlpack.from_dlpack(jax_array)
79+
try:
80+
return torch.utils.dlpack.from_dlpack(jax_array)
81+
except AttributeError:
82+
# jax.Array instances in earlier JAX versions lack a __dlpack_device__ attribute
83+
# and require explicitly packing into a DLPack capsule with jax.dlpack.to_dlpack
84+
return torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(jax_array))
8085

8186

8287
def torch_tensor_to_jax_array(torch_tensor: torch.Tensor) -> jax.Array:

0 commit comments

Comments
 (0)