We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 4d01933 commit a1b5d0bCopy full SHA for a1b5d0b
s2fft/utils/torch_wrapper.py
@@ -76,7 +76,12 @@ def jax_array_to_torch_tensor(jax_array: jax.Array) -> torch.Tensor:
76
Torch tensor object with equivalent data to `jax_array`.
77
78
"""
79
- return torch.utils.dlpack.from_dlpack(jax_array)
+ try:
80
+ return torch.utils.dlpack.from_dlpack(jax_array)
81
+ except AttributeError:
82
+ # jax.Array instances in earlier JAX versions lack a __dlpack_device__ attribute
83
+ # and require explicitly packing into a DLPack capsule with jax.dlpack.to_dlpack
84
+ return torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(jax_array))
85
86
87
def torch_tensor_to_jax_array(torch_tensor: torch.Tensor) -> jax.Array:
0 commit comments