In some code I'm writing, I need to permute a matrix along its rows, independently from one column to another.
I think the reason `jax.random.shuffle` is deprecated is that it is potentially confusing: `numpy.random.shuffle` permutes contents in place and does not permute independently along each row, so a function of the same name in `jax.random` with markedly different behavior could easily mislead users.

For what it's worth, you can recover the independent shuffling behavior by using `permutation` with `vmap`; for example:

```python
import jax
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(42)
x = jnp.broadcast_to(jnp.arange(3), (4, 3))  # every row is [0, 1, 2]

# Deprecated: shuffles each row independently along axis 1
print(random.shuffle(key, x, axis=1))
# [[0 2 1]
#  [1 0 2]
#  [1 2 0]
#  [2 1 0]]

# Same kind of behavior (different draws): one key per row, vmap over rows
print(jax.vmap(random.permutation)(random.split(key, x.shape[0]), x))
# [[1 2 0]
#  [0 2 1]
#  [2 1 0]
#  [1 0 2]]
```

... though it's admittedly less convenient. Perhaps the current
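As a side-by-side sketch: on newer JAX releases, `jax.random.permutation` also accepts an `independent` argument that shuffles each vector along `axis` separately, so the per-row shuffle can be done without `vmap`. The block below assumes a JAX version that has that argument (check your installed version), and additionally shows a generic swapped-in technique not mentioned in this thread: argsort i.i.d. uniform noise along the axis and gather with `take_along_axis`. The variable names `via_vmap`, `via_flag` and `via_argsort` are illustrative only.

```python
import jax
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(42)
x = jnp.broadcast_to(jnp.arange(3), (4, 3))  # every row is [0, 1, 2]

# 1) One key per row, vmap permutation over the rows (the approach above).
keys = random.split(key, x.shape[0])
via_vmap = jax.vmap(random.permutation)(keys, x)

# 2) Assumes a JAX version where permutation has the `independent` flag:
#    each row (vector along axis=1) gets its own random permutation.
via_flag = random.permutation(key, x, axis=1, independent=True)

# 3) Generic trick: sorting i.i.d. uniform noise per row yields an
#    independent random permutation index for each row.
idx = jnp.argsort(random.uniform(key, x.shape), axis=1)
via_argsort = jnp.take_along_axis(x, idx, axis=1)

print(via_vmap)
print(via_flag)
print(via_argsort)
```

All three produce a matrix in which every row is some permutation of the original row; the actual draws differ between methods, so don't expect identical outputs.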