How to serialize DeviceArray with pybind11 #11335
-
Dear community, I am trying to combine C++ with JAX. In my case, xla_client.ops.CustomCallWithLayout is not a good option, because my goal is to reuse existing CUDA code rather than to speed anything up. Moreover, some parts of the CUDA code are hard to rewrite in Python. When I tried to pass a DeviceArray as a pybind11 buffer, it said 'only CPU array can be converted to buffer'. Converting a DeviceArray to a pybind11 array_t works. However, I still do not know how to return a DeviceArray. I am also not sure whether the conversion involves frequent host-device memcpy. I noticed that tensorflow/compiler/xla has something called PyBuffer::object, but creating a jnp array that way is still too complicated. Best regards!
Replies: 2 comments 1 reply
-
Why |
-
But if you don't need to make your CUDA code compatible with `jax.jit`, using `jax.dlpack.to_dlpack` and `jax.dlpack.from_dlpack` might be the easiest way to achieve your goal.
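
For illustration, here is a minimal sketch of a DLPack round trip. It assumes a CPU backend and uses NumPy as a stand-in for the external library; `external_kernel` is a hypothetical placeholder for the pybind11-wrapped CUDA routine, not something from this thread. It also relies on JAX arrays exposing the `__dlpack__` protocol, whereas older releases would build an explicit capsule with `jax.dlpack.to_dlpack`.

```python
import numpy as np
import jax.numpy as jnp
import jax.dlpack


def external_kernel(buf: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for the pybind11-wrapped CUDA routine."""
    return buf * 2.0  # allocates a fresh, writable result buffer


x = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)

# Export: JAX arrays implement the __dlpack__ protocol, so a DLPack-aware
# consumer can import them directly. (Older JAX releases produced an explicit
# capsule with jax.dlpack.to_dlpack(x) instead.)
view = np.from_dlpack(x)

# Run the external code on the imported buffer.
result = external_kernel(view)

# Import: wrap the consumer's buffer as a JAX array again.
y = jax.dlpack.from_dlpack(result)
```

As the reply notes, this handoff happens outside `jax.jit`. On a GPU the same pattern should apply with a CUDA-aware DLPack consumer in place of NumPy, though that part is an assumption and is not exercised by this sketch.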