How to convert an xla buffer to a JAX DeviceArray? #20434
jing-alice asked this question in Q&A
Replies: 1 comment
-
Can you please:
-
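With jax==0.3.25, this works:

```python
from jax.core import ShapedArray
# Private class from jax 0.3.x; this import path is an assumption and the
# class no longer exists in jax 0.4.x.
from jax._src.device_array import _DeviceArray

def xla_buffer_to_jax_tensor(xla_buf):
    # Describe the buffer's shape/dtype, then wrap it without copying.
    aval = ShapedArray(xla_buf.shape, xla_buf.dtype)
    return _DeviceArray(aval, xla_buf.device(), xla_buf)
```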
With jax==0.4.15, I changed it to:

```python
import jax
from jax.core import ShapedArray

def xla_buffer_to_jax_tensor(xla_buf):
    aval = ShapedArray(xla_buf.shape, xla_buf.dtype)  # not used below
    print("xla_devices: ", xla_buf.device())
    # Place the buffer on the device it reports being on.
    return jax.device_put(xla_buf, xla_buf.device())
```

However, this doesn't work. It raises:

ValueError: Received incompatible devices for jitted computation. Got argument of jax_tensor_set with shape float32[10,4,4,256] and device ids [0] on platform GPU and argument of jax_tensor_set with shape float32[10,4,4,128] and device ids [0] on platform GPU.

The GPU ids are the same, so why are the devices considered incompatible?
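The thread has no accepted answer. One plausible explanation, offered here as an assumption rather than something confirmed in this discussion: jit compares device *objects*, not just platform and id, so a buffer created by a separate XLA client can report "GPU, id 0" while not being the same device object as JAX's own GPU 0. The sketch below is a diagnostic plus a copy-based workaround under that assumption. `placement_report` and `rehome` are hypothetical helper names, and the workaround assumes `xla_buf` can be converted with `np.asarray`, which costs a device-to-host-to-device copy.

```python
import jax
import numpy as np


def placement_report(x):
    """Return (platform, id, known_to_jax) for each device holding jax.Array x.

    known_to_jax is False when the device object is not one of the devices
    JAX's own backend reports for that platform, even if platform/id match.
    """
    report = []
    for d in x.devices():  # jax.Array.devices() returns a set in jax 0.4.x
        own_devices = jax.devices(d.platform)
        report.append((d.platform, d.id, d in own_devices))
    return report


def rehome(host_value, platform, device_id):
    """Copy host_value onto JAX's own device with the given platform and id.

    Goes through host memory via np.asarray, so it is a full copy, but the
    result lives on a device object that JAX itself created.
    """
    target = next(d for d in jax.devices(platform) if d.id == device_id)
    return jax.device_put(np.asarray(host_value), target)


if __name__ == "__main__":
    cpu = jax.devices("cpu")[0]
    a = rehome(np.ones((2, 3), np.float32), "cpu", cpu.id)
    print(a.shape, placement_report(a))
```

Applied to the question (untested against the original buffer type), calling `placement_report` on each `jax.device_put` result would show whether `known_to_jax` is False; if so, `dev = xla_buf.device()` followed by `rehome(xla_buf, dev.platform, dev.id)` would trade a copy for arrays that jit can combine.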