We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents 031614c + 4a82fe9 commit de9ad6bCopy full SHA for de9ad6b
jax/_src/random.py
@@ -670,8 +670,8 @@ def choice(key: ArrayLike,
670
ind = jnp.searchsorted(p_cuml, r).astype(int)
671
else:
672
# Gumbel top-k trick: https://timvieira.github.io/blog/post/2019/09/16/algorithms-for-sampling-without-replacement/
673
- g = -gumbel(key, (n_inputs,), dtype=p_arr.dtype) - jnp.log(p_arr)
674
- ind = jnp.argsort(g)[:n_draws]
+ g = gumbel(key, (n_inputs,), dtype=p_arr.dtype) + jnp.log(p_arr)
+ ind = lax.top_k(g, k=n_draws)[1].astype(int)
675
result = ind if arr.ndim == 0 else jnp.take(arr, ind, axis)
676
677
return result.reshape(shape if arr.ndim == 0 else
0 commit comments