Unstack operation in Jax #11028
conceptofmind asked this question in Q&A; answered by jakevdp.
Hello,

In PyTorch, we have unbind. For example:

    x1, x2 = x.unbind(dim=-2)

And in TensorFlow, we have unstack. For example:

    x1, x2 = tf.unstack(x, axis=-2)

Would the equivalent in JAX be the following?

    import jax.numpy as jnp

    def jax_unstack(x, axis=0):
        # Move `axis` to the front so that unpacking iterates over it.
        return jnp.moveaxis(x, axis, 0)

    x1, x2 = jax_unstack(x, axis=-2)

Any help would be greatly appreciated.

Thank you,
Enrico
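Tuple unpacking works with the moveaxis version because iterating over a JAX array yields slices along its leading axis. A minimal sketch checking this, assuming a hypothetical input of shape (4, 2, 3) so that axis -2 has length 2 (the shape is chosen only for illustration):

```python
import jax.numpy as jnp

def jax_unstack(x, axis=0):
    # Move the requested axis to the front; unpacking then iterates over it.
    return jnp.moveaxis(x, axis, 0)

# Hypothetical example input: axis -2 has length 2.
x = jnp.arange(24).reshape(4, 2, 3)
x1, x2 = jax_unstack(x, axis=-2)

# Each piece drops the unstacked axis, as torch.unbind / tf.unstack do.
print(x1.shape, x2.shape)  # (4, 3) (4, 3)
# x1 matches selecting index 0 along axis -2 directly.
print(bool(jnp.array_equal(x1, x[:, 0, :])))  # True
```

Note that the unpacking target count must match the length of the chosen axis, exactly as with the PyTorch and TensorFlow versions.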
Answer from jakevdp, Jun 8, 2022:
Yes, that seems like a concise and performant way to implement this functionality in JAX. You might avoid the transpose performed by jnp.moveaxis by using lax.index_in_dim instead:

    import jax.numpy as jnp
    from jax import lax

    def jax_unstack(x, axis=0):
        # Take each index along `axis` and drop that dimension from the result.
        return [lax.index_in_dim(x, i, axis, keepdims=False) for i in range(x.shape[axis])]

    x = jnp.arange(9).reshape(3, 3)
    jax_unstack(x, axis=0)
    # [DeviceArray([0, 1, 2], dtype=int32),
    #  DeviceArray([3, 4, 5], dtype=int32),
    #  DeviceArray([6, 7, 8], dtype=int32)]
    # (output as printed by JAX in 2022; newer versions print Array(...))
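Because x.shape[axis] is a static Python int, the list comprehension also traces cleanly under jax.jit, and the returned list of arrays is a valid pytree output. A sketch comparing both approaches on a 3-D input with a negative axis; the names unstack_moveaxis and unstack_index are hypothetical, and normalizing the axis up front is an illustrative defensive choice, not part of the answer above:

```python
import jax
import jax.numpy as jnp
from jax import lax

def unstack_moveaxis(x, axis=0):
    # Question's approach: move `axis` to the front, then split along it.
    return list(jnp.moveaxis(x, axis, 0))

def unstack_index(x, axis=0):
    # Answer's approach: take index i along `axis` and drop that dimension.
    axis = axis % x.ndim  # normalize a negative axis (defensive choice)
    return [lax.index_in_dim(x, i, axis, keepdims=False)
            for i in range(x.shape[axis])]

x = jnp.arange(24).reshape(2, 3, 4)

# `axis` must be static: it sets the Python loop length and the output count.
jitted = jax.jit(unstack_index, static_argnames="axis")

a = unstack_moveaxis(x, axis=-2)
b = jitted(x, axis=-2)

print(len(a), len(b))  # 3 3
print([p.shape for p in b])  # [(2, 4), (2, 4), (2, 4)]
print(all(bool(jnp.array_equal(p, q)) for p, q in zip(a, b)))  # True
```

If axis were passed as a traced (non-static) argument, x.shape[axis] would fail during tracing, which is why it is marked static here.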
Answer selected by conceptofmind.