Skip to content

Commit de9ad6b

Browse files
Merge pull request jax-ml#27157 from mar-muel:improve-random-choice-performance
PiperOrigin-RevId: 737665351
2 parents 031614c + 4a82fe9 commit de9ad6b

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

jax/_src/random.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -670,8 +670,8 @@ def choice(key: ArrayLike,
670670
ind = jnp.searchsorted(p_cuml, r).astype(int)
671671
else:
672672
# Gumbel top-k trick: https://timvieira.github.io/blog/post/2019/09/16/algorithms-for-sampling-without-replacement/
673-
g = -gumbel(key, (n_inputs,), dtype=p_arr.dtype) - jnp.log(p_arr)
674-
ind = jnp.argsort(g)[:n_draws]
673+
g = gumbel(key, (n_inputs,), dtype=p_arr.dtype) + jnp.log(p_arr)
674+
ind = lax.top_k(g, k=n_draws)[1].astype(int)
675675
result = ind if arr.ndim == 0 else jnp.take(arr, ind, axis)
676676

677677
return result.reshape(shape if arr.ndim == 0 else

0 commit comments

Comments
 (0)