In general, DLPack conversions should be zero-copy. You can confirm this with JAX's unsafe_buffer_pointer() method, which returns the memory address of the array's underlying buffer as an integer:

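import torch
import jax

def jax2torch(x):
    # Export the JAX array as a DLPack capsule and wrap it as a PyTorch tensor.
    return torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(x))

def torch2jax(x):
    # Export the PyTorch tensor as a DLPack capsule and wrap it as a JAX array.
    return jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(x))

x = jax.numpy.linspace(0, 10, 1000)

# Round trip: JAX -> PyTorch -> JAX.
x_torch = jax2torch(x)
x_jax = torch2jax(x_torch)

# Matching addresses show that the round trip reused the original buffer.
assert x.unsafe_buffer_pointer() == x_jax.unsafe_buffer_pointer()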
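
The round trip above compares two JAX arrays, so it only checks the intermediate PyTorch tensor indirectly. A minimal sketch of a direct check follows. It assumes a PyTorch version whose torch.from_dlpack() accepts any object implementing __dlpack__ (which JAX arrays do), and that both libraries can see the same device; same_buffer is a hypothetical helper name, not part of either library.

```python
import jax.numpy as jnp
import torch


def same_buffer(jax_array, torch_tensor):
    """Hypothetical helper: True if both objects start at the same address."""
    # unsafe_buffer_pointer() (JAX) and data_ptr() (PyTorch) both return
    # the address of the first element as a Python int.
    return jax_array.unsafe_buffer_pointer() == torch_tensor.data_ptr()


x = jnp.linspace(0, 10, 1000)

# Assumption: torch.from_dlpack() consumes the JAX array directly through
# the __dlpack__ protocol, with no explicit capsule step.
x_torch = torch.from_dlpack(x)

shares = same_buffer(x, x_torch)
print("zero-copy:", shares)
```

If this prints False, one of the conversions copied the data. Treat the tensor as read-only either way: JAX arrays are immutable, so writing through the PyTorch view of a shared buffer is unsafe.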

Answer selected by fangwei123456