We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent c21ebd4 commit 86276f8Copy full SHA for 86276f8
simplexity/utils/pytorch_utils.py
@@ -27,7 +27,6 @@ def jax_to_torch(jax_array: jax.Array) -> torch.Tensor:
27
28
Args:
29
jax_array: JAX array to convert
30
- device: Target PyTorch device (optional, will use JAX array's device if None)
31
32
Returns:
33
PyTorch tensor
@@ -36,8 +35,7 @@ def jax_to_torch(jax_array: jax.Array) -> torch.Tensor:
36
35
ImportError: If JAX or PyTorch is not available
37
"""
38
try:
39
- dlpack_tensor = jax_dlpack.to_dlpack(jax_array)
40
- torch_tensor = torch_dlpack.from_dlpack(dlpack_tensor)
+ torch_tensor = torch_dlpack.from_dlpack(jax_array)
41
return torch_tensor
42
43
except Exception as e:
0 commit comments