dlpack confusion #11494
import jax
import jax.dlpack as jdp
import jax.numpy as jnp
a = jnp.arange(10)
b = jdp.to_dlpack(a)
hasattr(b, '__dlpack__')
# result: False
I want to convert a jax.Array to DLPack and send/receive it with mpi4py, but this fails because the object lacks a `__dlpack__` attribute. The value returned by `jax.dlpack.to_dlpack` doesn't have that attribute; is this intended? I also want to convert the DLPack object to a NumPy array, and the `__dlpack__` attribute is required by the `numpy.from_dlpack` function.
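To make the distinction concrete, here is a minimal sketch that uses NumPy (1.22 or later) as a stand-in producer. This is an assumption made so it runs without JAX. An array implements `__dlpack__`, while the capsule returned by calling `__dlpack__()` (the same kind of object `jax.dlpack.to_dlpack` returns) is an opaque `PyCapsule` without it. Consumers such as `np.from_dlpack` expect the former.

```python
import numpy as np

# A NumPy array is a DLPack producer: it implements the protocol method.
x = np.arange(10)
print(hasattr(x, "__dlpack__"))        # True

# Calling __dlpack__() yields the raw capsule, analogous to what
# jax.dlpack.to_dlpack returns. The capsule has no protocol methods.
capsule = x.__dlpack__()
print(type(capsule).__name__)          # PyCapsule
print(hasattr(capsule, "__dlpack__"))  # False

# np.from_dlpack wants the producer object, not the capsule.
y = np.from_dlpack(x)
print(y)                               # [0 1 2 3 4 5 6 7 8 9]
```

If your JAX release already implements `__dlpack__` on `jax.Array`, passing `a` itself rather than `jdp.to_dlpack(a)` is likely the simplest route, both to NumPy and, as far as we know, to mpi4py's buffer-based `Send`/`Recv` (3.1+). Check the documentation for your installed versions.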
Replies: 1 comment 3 replies
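Hi - this looks like an unfortunate design choice made by numpy: the `numpy.from_dlpack` function does not accept a DLPack capsule directly, but rather accepts an object with a `__dlpack__()` method that returns one. You can work around this issue as follows:
import jax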
import jax.dlpack as jdp
import jax.numpy as jnp
import numpy as np
class NumpyCompatibleWrapper:
    def __init__(self, dlpack_buffer):
        self.dlpack_buffer = dlpack_buffer
    def __dlpack__(self):
        return self.dlpack_buffer
a = jnp.arange(10)
b = jdp.to_dlpack(a)
b_wrapper = NumpyCompatibleWrapper(b)
c = np.from_dlpack(b_wrapper)
print(c)
# [0 1 2 3 4 5 6 7 8 9]
I don't think there's any reasonable way for us to fix this on the JAX side: we could, of course, add a …
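A slightly more defensive variant of the wrapper is sketched below. `CapsuleWrapper` is a hypothetical name, and NumPy again stands in for JAX as the capsule producer. Newer DLPack consumers may call `__dlpack__` with keyword arguments such as `stream`, `max_version`, or `copy`; this version accepts and ignores them, which is only reasonable for CPU data that needs no stream synchronization. A capsule can be consumed only once, so each wrapper should be passed to a single import.

```python
import numpy as np


class CapsuleWrapper:
    """Hypothetical helper: exposes a bare DLPack capsule through the
    __dlpack__ protocol so that consumers like np.from_dlpack accept it.

    The capsule is single-use: once a consumer imports it, the capsule is
    marked as used, so wrap each capsule for exactly one import.
    """

    def __init__(self, capsule):
        self._capsule = capsule

    def __dlpack__(self, *args, **kwargs):
        # Consumers may pass keyword arguments such as stream or max_version.
        # This sketch ignores them and returns the stored capsule unchanged,
        # which is only safe for CPU data with no stream synchronization.
        return self._capsule


# NumPy stands in for JAX here; with JAX the capsule would instead come
# from jax.dlpack.to_dlpack(a).
src = np.arange(10)
capsule = src.__dlpack__()
c = np.from_dlpack(CapsuleWrapper(capsule))
print(c)  # [0 1 2 3 4 5 6 7 8 9]
```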