Hi, this looks like an unfortunate design choice in NumPy: the `numpy.from_dlpack` function does not accept a DLPack capsule directly. Instead, it expects an object with a `__dlpack__()` method that returns the capsule. You can work around this as follows:

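```python
import jax.dlpack as jdp
import jax.numpy as jnp
import numpy as np

class NumpyCompatibleWrapper:
  """Exposes a raw DLPack capsule through the __dlpack__() protocol."""

  def __init__(self, dlpack_buffer):
    self.dlpack_buffer = dlpack_buffer  # PyCapsule from jdp.to_dlpack

  def __dlpack__(self):
    # np.from_dlpack calls this method to obtain the capsule.
    return self.dlpack_buffer

a = jnp.arange(10)
b = jdp.to_dlpack(a)  # raw capsule; np.from_dlpack(b) alone would fail
b_wrapper = NumpyCompatibleWrapper(b)

c = np.from_dlpack(b_wrapper)
print(c)
# [0 1 2 3 4 5 6 7 8 9]
```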
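One thing to keep in mind with this wrapper: a DLPack capsule is single-use. When `np.from_dlpack` imports it, NumPy renames the capsule (to `used_dltensor`), so passing the same `NumpyCompatibleWrapper` instance to `np.from_dlpack` a second time is expected to fail. If the conversion has to be repeatable, one option is to keep the source array and build a fresh capsule on every `__dlpack__()` call. The sketch below uses a hypothetical `ReexportingWrapper` class (not a JAX or NumPy API); its `export_fn` argument stands in for whatever produces the capsule (`jdp.to_dlpack` in the JAX code above). To run without JAX, the demo uses NumPy's own `ndarray.__dlpack__` as the exporter.

```python
import numpy as np


class ReexportingWrapper:
    """Hypothetical wrapper: builds a new DLPack capsule on each request."""

    def __init__(self, source, export_fn):
        self.source = source        # the array that owns the data
        self.export_fn = export_fn  # callable: array -> fresh capsule

    def __dlpack__(self):
        # A new capsule per call, so each np.from_dlpack gets an unused one.
        return self.export_fn(self.source)


src = np.arange(4)
# Assumption: with JAX this could be ReexportingWrapper(jax_array, jdp.to_dlpack).
wrapper = ReexportingWrapper(src, lambda arr: arr.__dlpack__())

first = np.from_dlpack(wrapper)
second = np.from_dlpack(wrapper)  # works: a fresh capsule was exported
print(first, second)  # [0 1 2 3] [0 1 2 3]
```

The trade-off is that the wrapper holds a reference to the source array, keeping it alive for as long as the wrapper exists.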

I don't think there's any reasonable way for u…
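The same workaround can be generalized for code that may receive either a raw capsule or an object that already implements `__dlpack__()`. The hypothetical helper `from_dlpack_any` below (an illustration, not part of NumPy or JAX) passes protocol-aware objects straight to `np.from_dlpack` and wraps anything else in a capsule adapter. It assumes that any object without `__dlpack__` is a DLPack capsule. It uses only NumPy so it runs without JAX; a capsule returned by `jdp.to_dlpack` would take the wrapping branch.

```python
import numpy as np


class _CapsuleWrapper:
    """Hypothetical adapter: exposes a raw DLPack capsule via __dlpack__()."""

    def __init__(self, capsule):
        self._capsule = capsule

    def __dlpack__(self):
        # Hand back the stored capsule; np.from_dlpack consumes it.
        return self._capsule


def from_dlpack_any(obj):
    """Hypothetical helper: accept a DLPack producer or a raw capsule."""
    if hasattr(obj, "__dlpack__"):
        # Already implements the protocol (e.g. a NumPy array): pass through.
        return np.from_dlpack(obj)
    # Otherwise assume obj is a raw capsule and wrap it.
    return np.from_dlpack(_CapsuleWrapper(obj))


src = np.arange(10)
print(from_dlpack_any(src))               # [0 1 2 3 4 5 6 7 8 9]
print(from_dlpack_any(src.__dlpack__()))  # [0 1 2 3 4 5 6 7 8 9]
```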

Replies: 1 comment 3 replies

@jakevdp
@jakevdp
@YuzhiLiu-ai
Answer selected by YuzhiLiu-ai
Category: Q&A
Labels: none
Participants: 2