Yes, that seems like a concise and performant way to implement this in JAX. You can avoid the reshape by using `lax.index_in_dim` with `keepdims=False`, which drops the indexed axis:

```python
import jax.numpy as jnp
from jax import lax

def jax_unstack(x, axis=0):
  # One slice per index along `axis`; keepdims=False removes that axis from each slice.
  return [lax.index_in_dim(x, i, axis, keepdims=False) for i in range(x.shape[axis])]

x = jnp.arange(9).reshape(3, 3)
jax_unstack(x, axis=0)
# [DeviceArray([0, 1, 2], dtype=int32),
#  DeviceArray([3, 4, 5], dtype=int32),
#  DeviceArray([6, 7, 8], dtype=int32)]
```
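A slightly extended usage sketch, for illustration only: the 2×3 input, the round-trip check with `jnp.stack`, and the `jax.jit` call are additions, not part of the answer above. Because the list comprehension needs `x.shape[axis]` as a plain Python int, `axis` has to be marked static when the function is jitted; the loop then unrolls at trace time into one slice per index.

```python
import jax
import jax.numpy as jnp
from jax import lax

def jax_unstack(x, axis=0):
  # One slice per index along `axis`; keepdims=False removes that axis from each slice.
  return [lax.index_in_dim(x, i, axis, keepdims=False) for i in range(x.shape[axis])]

x = jnp.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]

rows = jax_unstack(x, axis=0)  # 2 arrays, each of shape (3,)
cols = jax_unstack(x, axis=1)  # 3 arrays, each of shape (2,)

# Stacking the pieces along the same axis recovers the original array.
roundtrip = jnp.stack(cols, axis=1)

# `axis` (positional argument 1) must be static: it is used in x.shape[axis]
# and range(), which require concrete Python ints at trace time.
unstack_jit = jax.jit(jax_unstack, static_argnums=1)
rows_jit = unstack_jit(x, 0)
```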

Answer selected by conceptofmind
Category: Q&A