I have an …

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)
a = jnp.arange(15*5*3).reshape(15, 5, 3)
jax.random.choice(key, a, shape=(5, 5, 3)).shape == (5, 5, 3)
```

Is there an efficient way to do this without iterating through … Bonus: …
Replies: 2 comments, 3 replies
You just want `shape=(5,)`. I'm not sure what it would mean to use an …
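A minimal sketch of what `shape=(5,)` does, reusing the `key` and `a` from the question and assuming a JAX version whose `jax.random.choice` accepts a multi-dimensional `a` (i.e. has an `axis` argument, default `0`). `choice` draws indices along `axis` and takes whole slices, so the result has shape `(5,) + a.shape[1:]`. With a single key, each sample is one row index shared by every `(j, k)` position. The check at the end relies on `a[i, j, k] == 15*i + 3*j + k`, which holds only for this particular `a`.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)
a = jnp.arange(15 * 5 * 3).reshape(15, 5, 3)

# choice samples along axis 0 by default; the result has shape
# shape + a.shape[1:], i.e. (5,) + (5, 3) -> (5, 5, 3)
out = jax.random.choice(key, a, shape=(5,))
print(out.shape)  # (5, 5, 3)

# Each sample is a whole slice a[i]: recover i from one element
# (a[i, 0, 0] == 15 * i) and compare the full slices.
rows = out[:, 0, 0] // (5 * 3)
print(bool(jnp.all(out == a[rows])))  # True
```

If you instead want an independent draw for every `(j, k)` position (or a different `p` per position), a single call like this is not enough; the draws have to be mapped over those positions.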
For your first question, I believe @davisyoshida is right that you just need to pass `shape=(5,)`. An alternative way to do this kind of sampling is by using `vmap`; it might look like this:

```python
from functools import partial
import jax
import jax.numpy as jnp
key = jax.random.PRNGKey(42)
a = jnp.arange(15*5*3).reshape(15, 5, 3)
# one independent key per (j, k) position: shape (5, 3) + key.shape
split_key = jax.random.split(key, a[0].size).reshape(*a[0].shape, *key.shape)
choice = partial(jax.random.choice, shape=(5,))  # 5 draws from a 1-D array
choice = jax.vmap(choice, in_axes=(0, 1), out_axes=1)  # keys axis 0, a axis 1
choice = jax.vmap(choice, in_axes=(1, 2), out_axes=2)  # keys axis 1, a axis 2
out = choice(split_key, a)
out.shape  # (5, 5, 3)
```

The `vmap` approach is somewhat more complicated, but it allows you to express exactly the mapping semantics that you wish, rather than relying on the implicit broadcasting/vectorization semantics of the function itself. This be…

```python
import numpy as np
from functools import partial
import jax

# random per-position probabilities, normalized along axis 0
p = np.random.rand(*a.shape)
p /= p.sum(0, keepdims=True)

@partial(jax.vmap, in_axes=(1, 2, 2), out_axes=2)  # outer: axis 2 of a and p
@partial(jax.vmap, in_axes=(0, 1, 1), out_axes=1)  # inner: axis 1 of a and p
def choice(key, a, p):
  return jax.random.choice(key, a, shape=(5,), p=p)

out = choice(split_key, a, p)
out.shape  # (5, 5, 3)
```
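For the bonus case (a separate probability vector `p[:, j, k]` for every position), one alternative that is not taken from the replies above is to swap `jax.random.choice` for `jax.random.categorical` with logits `log(p)`, then gather the sampled rows with `jnp.take_along_axis`. This is a sketch under those assumptions: the helper name `sample_axis0` is hypothetical, it samples with replacement only, and it relies on `categorical`'s `shape` argument, which must end with the batch shape `logits.shape[:-1]`. Because `categorical` uses the Gumbel-max trick, entries with zero probability (logit `-inf`) are never drawn.

```python
import jax
import jax.numpy as jnp
import numpy as np


def sample_axis0(key, a, p, num_samples):
    """Hypothetical helper: draw `num_samples` samples along axis 0 of `a`,
    independently for each position (j, k), using probabilities p[:, j, k].
    Samples with replacement."""
    # Move the category axis (axis 0) last so it lines up with axis=-1.
    logits = jnp.moveaxis(jnp.log(p), 0, -1)  # (15, 5, 3) -> (5, 3, 15)
    # One categorical draw per output element; the requested shape ends
    # with the batch shape logits.shape[:-1] == a.shape[1:].
    idx = jax.random.categorical(
        key, logits, axis=-1, shape=(num_samples, *a.shape[1:])
    )  # (num_samples, 5, 3), values in [0, 15)
    # out[s, j, k] = a[idx[s, j, k], j, k]
    return jnp.take_along_axis(a, idx, axis=0)


key = jax.random.PRNGKey(42)
a = jnp.arange(15 * 5 * 3).reshape(15, 5, 3)
p = np.random.rand(*a.shape)
p /= p.sum(0, keepdims=True)

out = sample_axis0(key, a, p, num_samples=5)
print(out.shape)  # (5, 5, 3)
```

Compared with the nested `vmap`, this needs no key splitting and states the layout once via `moveaxis` rather than through two sets of `in_axes`; the `vmap` version stays closer to the semantics of `choice` itself.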