File tree Expand file tree Collapse file tree 1 file changed +6
-1
lines changed
Expand file tree Collapse file tree 1 file changed +6
-1
lines changed Original file line number Diff line number Diff line change @@ -107,7 +107,12 @@ def torch_tensor_to_jax_array(torch_tensor: torch.Tensor) -> jax.Array:
107107 # this intended for use when wrapping JAX code detaching tensor from gradient values
108108 # should not be problematic as derivatives will be separately routed via JAX
109109 torch_tensor = torch_tensor .detach ()
110- return jax .dlpack .from_dlpack (torch_tensor )
110+ try :
111+ return jax .dlpack .from_dlpack (torch_tensor )
112+ except TypeError :
113+ # earlier JAX versions require explicitly converting external arrays to
114+ # DLPack capsule before passing to jax.dlpack.from_dlpack
115+ return jax .dlpack .from_dlpack (torch .utils .dlpack .to_dlpack (torch_tensor ))
111116
112117
113118def tree_map_jax_array_to_torch_tensor (
You can’t perform that action at this time.
0 commit comments