Skip to content

Commit 8d9a2ff

Browse files
committed
More maintaining compatibility with older JAX versions
1 parent a1b5d0b commit 8d9a2ff

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

s2fft/utils/torch_wrapper.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,12 @@ def torch_tensor_to_jax_array(torch_tensor: torch.Tensor) -> jax.Array:
107107
# this intended for use when wrapping JAX code detaching tensor from gradient values
108108
# should not be problematic as derivatives will be separately routed via JAX
109109
torch_tensor = torch_tensor.detach()
110-
return jax.dlpack.from_dlpack(torch_tensor)
110+
try:
111+
return jax.dlpack.from_dlpack(torch_tensor)
112+
except TypeError:
113+
# earlier JAX versions require explicitly converting external arrays to
114+
# DLPack capsule before passing to jax.dlpack.from_dlpack
115+
return jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(torch_tensor))
111116

112117

113118
def tree_map_jax_array_to_torch_tensor(

0 commit comments

Comments
 (0)