
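Using `nanmedian` seems cleanest, but if you want to avoid introducing NaNs, you could do something similar to the implementation of `nanquantile`: move the unselected entries to the end of the sorted array, count the selected ones, and take the middle of that prefix. It might look more or less like this: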

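```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def median_where(x, where):
  # Shapes and dtypes are static under jit, so these checks run at trace time.
  assert x.ndim == 1
  assert x.shape == where.shape
  assert where.dtype == bool
  N = jnp.sum(where)  # number of selected elements
  # Unselected entries become x.max(), so after sorting, the first N entries
  # are exactly the selected values in ascending order.
  x_sorted = jnp.sort(jnp.where(where, x, x.max()))
  # Median of those N values: average the two middle elements, which are the
  # same element when N is odd. (N == 0 returns an arbitrary value, not NaN.)
  return 0.5 * (x_sorted[(N - 1) // 2] + x_sorted[N // 2])

np.random.seed(42)
x = jnp.array(np.random.randn(100))
where = (x > 0)
print(jnp.median(x[where]))
# 0.52791375
print(median_where(x, where))
# 0.52791375
```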
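For comparison, the `nanmedian` route is short: replace the unselected entries with NaN and let `jnp.nanmedian` skip them. The sketch below is an illustration, not code from the thread; the helper name `nanmedian_where` and the small test array are assumptions. It checks the result against `jnp.median(x[where])` for both an odd and an even number of selected elements, since those two cases pick the middle element(s) differently.

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def nanmedian_where(x, where):
  # Hypothetical helper: unselected entries become NaN, which jnp.nanmedian
  # ignores. If nothing is selected, every entry is NaN and so is the result.
  return jnp.nanmedian(jnp.where(where, x, jnp.nan))

x = jnp.array(np.random.default_rng(0).normal(size=10))
for k in (5, 6):  # odd and even numbers of selected elements
  where = jnp.arange(10) < k
  # Boolean-mask indexing works outside jit, so use it as the reference.
  print(k, nanmedian_where(x, where), jnp.median(x[where]))
```

This keeps the masked computation jit-friendly (fixed shapes), at the cost of NaNs appearing in intermediate values, which is what the sort-based approach avoids.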

Answer selected by jeffgortmaker
Category: Q&A · Labels: none yet · 2 participants