I think `jax.random.shuffle` was deprecated because its name invites confusion. `numpy.random.shuffle` permutes its input in place and moves whole rows along the first axis; it does not permute each row independently. A function with the same name in `jax.random` but markedly different behavior could mislead users.
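To make the contrast concrete, here is a small NumPy sketch. The array values and seeds are arbitrary choices for illustration. `numpy.random.shuffle` reorders whole rows in place and returns `None`. Independent per-row shuffling in NumPy comes from a different function, `Generator.permuted` (NumPy 1.20 and later).

```python
import numpy as np

x = np.arange(12).reshape(4, 3)
original_rows = {tuple(row) for row in x.tolist()}

# numpy.random.shuffle mutates x in place and returns None. It reorders
# whole rows (the first axis); the order of elements inside each row is kept.
np.random.seed(0)
result = np.random.shuffle(x)
print(result)  # None
print(x)

# Generator.permuted shuffles each slice along `axis` independently and
# returns a new array. With axis=1, every row gets its own random order.
rng = np.random.default_rng(0)
y = rng.permuted(np.arange(12).reshape(4, 3), axis=1)
print(y)
```

After the shuffle, every row of `x` is still one of the original rows (e.g. `[3, 4, 5]`), only in a different position. Each row of `y` keeps its own elements, but in a scrambled order.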

For what it's worth, you can recover the independent shuffling behavior by combining `random.permutation` with `jax.vmap`. For example:

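```python
import jax
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(42)
# A 4x3 array in which every row is [0, 1, 2].
x = jnp.broadcast_to(jnp.arange(3), (4, 3))

# Deprecated API: shuffles each row independently along axis 1.
# (May be unavailable in newer JAX releases.)
print(random.shuffle(key, x, axis=1))
# [[0 2 1]
#  [1 0 2]
#  [1 2 0]
#  [2 1 0]]

# Same kind of independent shuffling: one key per row, with permutation mapped over rows.
print(jax.vmap(random.permutation)(random.split(key, x.shape[0]), x))
#…
```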
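In more recent JAX releases, `jax.random.permutation` also accepts `axis` and `independent` keyword arguments. This lets you write the same per-row shuffling without `vmap`. The minimal sketch below reuses the 4x3 array from the example above. The exact orderings depend on the key and are not reproduced here.

```python
import jax
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(42)
x = jnp.broadcast_to(jnp.arange(3), (4, 3))  # every row is [0, 1, 2]

# vmap approach: split one key per row, then permute each row separately.
keys = random.split(key, x.shape[0])
y_vmap = jax.vmap(random.permutation)(keys, x)

# independent=True: each vector along `axis` (here, each row) is permuted
# independently. With the default independent=False, permutation would
# instead reorder whole slices along `axis`.
y_indep = random.permutation(key, x, axis=1, independent=True)

print(y_vmap)
print(y_indep)
```

Both versions give every row its own random order. They need not produce identical arrays for the same key, because they consume randomness differently.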

Answer selected by maurorigo