Will the tensor data be copied when converting tensors between pytorch and jax? #18765
Answered by jakevdp
fangwei123456 asked this question in Q&A
Hi, I want to use JAX to accelerate functions in PyTorch, so I need to convert tensors between the two frameworks. I use the following code:

```python
import torch
import jax

def jax2torch(x):
    return torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(x))

def torch2jax(x):
    return jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(x))
```

Will the tensor data be copied (which would be very slow) when converting tensors between PyTorch and JAX?
Answered by jakevdp on Dec 1, 2023
In general, dlpack conversions should be zero-copy. You can confirm this by using JAX's `unsafe_buffer_pointer()` method, which returns an integer representation of the memory address of the array's memory buffer:

```python
import torch
import jax

def jax2torch(x):
    return torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(x))

def torch2jax(x):
    return jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(x))

x = jax.numpy.linspace(0, 10, 1000)
x_torch = jax2torch(x)
x_jax = torch2jax(x_torch)
# Identical addresses before and after the round trip mean no copy was made.
assert x.unsafe_buffer_pointer() == x_jax.unsafe_buffer_pointer()
```
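As a hedged variation on the check described above: recent PyTorch and JAX releases can, to my understanding, exchange arrays through the `__dlpack__` protocol directly, so the explicit `to_dlpack` capsule step may not be needed. The sketch below assumes such versions are installed and that both arrays live on the CPU; `same_torch` and `same_jax` are illustrative names, not part of either library. It also compares against `torch.Tensor.data_ptr()` so the intermediate PyTorch tensor is checked as well, not only the round-tripped JAX array.

```python
import jax
import jax.numpy as jnp
import torch


def jax2torch(x):
    # Assumption: torch.from_dlpack accepts objects implementing __dlpack__,
    # which JAX arrays do in recent releases.
    return torch.from_dlpack(x)


def torch2jax(x):
    # Assumption: jax.dlpack.from_dlpack accepts the tensor itself rather
    # than a DLPack capsule in recent JAX releases.
    return jax.dlpack.from_dlpack(x)


x = jnp.linspace(0, 10, 1000)
x_torch = jax2torch(x)
x_jax = torch2jax(x_torch)

# Compare raw buffer addresses: equal addresses indicate shared memory.
same_torch = x.unsafe_buffer_pointer() == x_torch.data_ptr()
same_jax = x.unsafe_buffer_pointer() == x_jax.unsafe_buffer_pointer()
print("JAX -> PyTorch shares memory:", same_torch)
print("round trip shares memory:", same_jax)
```

If either comparison prints `False` on your setup, a copy was made somewhere in the conversion, which is worth checking for the devices and library versions you actually use.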