Clean masked median? #21811
I'm having trouble coming up with a clean way to compute a median over a (dynamic) mask in a jitted function. Means are simple because of the `where` argument to `jnp.mean`. I'm wondering if there's a cleaner way of doing this for the median with JAX that I'm not thinking of?
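For comparison, a masked mean needs no special handling: `jnp.mean` accepts a `where` mask, and that mask can be a traced array inside `jit`. A minimal sketch (the helper name `mean_where` and the sample data are hypothetical, for illustration only):

```python
import jax
import jax.numpy as jnp


@jax.jit
def mean_where(x, where):
    # Hypothetical helper: jnp.mean's `where` argument excludes masked-out
    # entries from both the sum and the count, even when the mask is traced.
    return jnp.mean(x, where=where)


x = jnp.array([1.0, 2.0, 3.0, 10.0])
mask = jnp.array([True, True, True, False])
print(mean_where(x, mask))  # 2.0
```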
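Using `nanmedian` seems the cleanest, but if you want to avoid introducing NaNs you could do something similar to the implementation of `nanquantile`. It might look more-or-less like this:

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def median_where(x, where):
    """Median of x[where] without boolean indexing, so it works under jit."""
    assert x.ndim == 1
    assert x.shape == where.shape
    assert where.dtype == bool
    N = jnp.sum(where)  # number of selected elements (a traced value)
    # Replace masked-out entries with the maximum so that, after sorting,
    # the first N entries are exactly the selected values.
    x_sorted = jnp.sort(jnp.where(where, x, x.max()))
    # Median of the first N sorted values: average the two middle entries,
    # which coincide when N is odd. The result is meaningless if N == 0.
    lo = x_sorted[(N - 1) // 2]
    hi = x_sorted[N // 2]
    return (lo + hi) / 2


np.random.seed(42)
x = jnp.array(np.random.randn(100))
where = (x > 0)
print(jnp.median(x[where]))
# 0.52791375
print(median_where(x, where))
# 0.52791375
```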
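The `nanmedian` route mentioned in the answer can be sketched as follows. This is a minimal illustration, not code from the thread: excluded entries are replaced with NaN, which `jnp.nanmedian` skips. It assumes `x` is floating point (integer arrays cannot hold NaN); the name `nanmedian_where` and the sample data are hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.jit
def nanmedian_where(x, where):
    # Hypothetical helper: masked-out entries become NaN, which jnp.nanmedian
    # ignores. x must be floating point; an all-False mask yields NaN.
    return jnp.nanmedian(jnp.where(where, x, jnp.nan))


x = jnp.array([3.0, -1.0, 5.0, 2.0, 100.0])
mask = jnp.array([True, True, True, True, False])
print(nanmedian_where(x, mask))  # 2.5
```

One consequence of this choice is that any NaNs already present in `x` are silently treated as masked out as well.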