Get all non-masked elements of an array, and pad the rest #10971
-
Hello, I'm struggling with the following simple problem: I have an array of size
What I want is to get all non-NaN elements and pad the rest, like this:
I can assume my non-NaN elements are sorted, and I could call
Any idea how to perform this operation efficiently? Any help is welcome, thanks in advance!
-
I think calling `jnp.argsort` on the NaN mask would work:
import jax.numpy as jnp
x = jnp.array([0., jnp.nan, jnp.nan, jnp.nan, 3., jnp.nan, 6.])
# isnan gives a boolean mask; False (non-NaN) sorts before True (NaN)
print(x[jnp.argsort(jnp.isnan(x))])
# [ 0. 3. 6. nan nan nan nan]
Since JAX uses a stable sort, it will keep your non-NaN values in their original order.
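If the padding should be something other than NaN, the same stable-sort trick can be followed by a `jnp.where`. Here is a minimal sketch of that idea; the helper name `compact_non_nan` and its `fill_value` parameter are made up for illustration and are not part of JAX. Because the output has the same shape as the input, it should also work under `jax.jit`.

```python
import jax
import jax.numpy as jnp


@jax.jit
def compact_non_nan(x, fill_value):
    # Hypothetical helper: move non-NaN entries to the front (order preserved),
    # then overwrite the trailing NaNs with fill_value.
    order = jnp.argsort(jnp.isnan(x))  # stable: False (non-NaN) before True (NaN)
    y = x[order]
    return jnp.where(jnp.isnan(y), fill_value, y)


x = jnp.array([0., jnp.nan, jnp.nan, jnp.nan, 3., jnp.nan, 6.])
result = compact_non_nan(x, -1.0)
print(result)
```

The cost is that of a sort over the whole array, which is usually acceptable unless the array is very large.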
-
If `k` is a constant, you may use `x[jnp.nonzero(~jnp.isnan(x), size=k)]`, then pad it to `n` elements. Actually, this is a counting sort.
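A minimal sketch of this suggestion, assuming `k` (the number of non-NaN entries) is known statically and `n` is the length of `x`; the example array and the choice of NaN as the pad value are illustrative:

```python
import jax.numpy as jnp

x = jnp.array([0., jnp.nan, jnp.nan, jnp.nan, 3., jnp.nan, 6.])
n = x.shape[0]
k = 3  # assumption: number of non-NaN entries, known ahead of time

# With a static `size`, nonzero returns exactly k indices, in increasing order.
(idx,) = jnp.nonzero(~jnp.isnan(x), size=k)
packed = x[idx]

# Pad on the right up to the original length n (NaN chosen here as the pad value).
padded = jnp.pad(packed, (0, n - k), constant_values=jnp.nan)
print(padded)
```

Passing `size` is what keeps the output shape static, so this can be used inside `jax.jit`. Note that if `x` has fewer than `k` non-NaN entries, `jnp.nonzero` fills the missing indices with its `fill_value` (0 by default), and if it has more, the extra ones are dropped.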