For your first question, I believe @davisyoshida is right that you just need to pass `shape=(5,)` to `jax.random.choice`. An alternative way to do this kind of sampling is to use `jax.vmap`; it might look like this:

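```python
from functools import partial
import jax
import jax.numpy as jnp
key = jax.random.PRNGKey(0)
a = jnp.arange(10 * 5 * 3).reshape(10, 5, 3)  # hypothetical stand-in; any (N, 5, 3) array works
split_key = jax.random.split(key, a[0].size).reshape(*a[0].shape, *key.shape)  # one key per (i, j)
choice = partial(jax.random.choice, shape=(5,))  # draw 5 samples from a 1-D column
choice = jax.vmap(choice, in_axes=(0, 1), out_axes=1)  # inner: map over axis 1 of a (size 5)
choice = jax.vmap(choice, in_axes=(1, 2), out_axes=2)  # outer: map over axis 2 of a (size 3)
out = choice(split_key, a)
out.shape  # (5, 5, 3)
```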
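
To make the mapping semantics of the nested `vmap` concrete, here is a minimal sketch that compares it against an explicit Python loop calling `jax.random.choice` once per `(i, j)` position, using that position's key `split_key[i, j]` and the length-N column `a[:, i, j]`. The input `a` (a `(10, 5, 3)` array built with `jnp.arange`) and `key` are hypothetical stand-ins. The comparison assumes, as `vmap` is designed to guarantee, that mapping over a batch of keys gives the same draws as looping over those keys one at a time.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical input of shape (N, 5, 3): N = 10 candidates per (i, j) position.
N = 10
a = jnp.arange(N * 5 * 3).reshape(N, 5, 3)
key = jax.random.PRNGKey(0)

# Nested-vmap version, as in the answer.
split_key = jax.random.split(key, a[0].size).reshape(*a[0].shape, *key.shape)
choice = partial(jax.random.choice, shape=(5,))
choice = jax.vmap(choice, in_axes=(0, 1), out_axes=1)  # maps over axis i of a
choice = jax.vmap(choice, in_axes=(1, 2), out_axes=2)  # maps over axis j of a
out = choice(split_key, a)

# Explicit-loop reference: one call per (i, j) with its own key and the
# column a[:, i, j]; results are stacked back so that ref[k, i, j] is the
# k-th draw for position (i, j), matching the layout of out.
ref = jnp.stack(
    [
        jnp.stack(
            [jax.random.choice(split_key[i, j], a[:, i, j], shape=(5,)) for j in range(3)],
            axis=-1,
        )
        for i in range(5)
    ],
    axis=1,
)
print(out.shape, ref.shape, bool(jnp.array_equal(out, ref)))
```

When the two agree, `out[k, i, j]` is the k-th draw from column `a[:, i, j]` using key `split_key[i, j]`, which is exactly what the `in_axes`/`out_axes` arguments spell out.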

The vmap approach is somewhat more complicated, but it allows you to express exactly the mapping semantics that you wish, rather than relying on the implicit broadcasting/vectorization semantics of the function itself. This be…
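
For contrast, here is a minimal sketch of relying on `jax.random.choice`'s own semantics instead of `vmap`. With `shape=(5,)` and `axis=0`, it draws 5 indices along axis 0 and returns whole `a[n]` slices, so every `(i, j)` position within one draw shares the same row `n` instead of being sampled independently. The `(10, 5, 3)` input `a` is a hypothetical stand-in, and this only illustrates the function's built-in behavior, not necessarily what the question needed.

```python
import jax
import jax.numpy as jnp

# Hypothetical input of shape (N, 5, 3) with N = 10 candidates per position.
N = 10
a = jnp.arange(N * 5 * 3).reshape(N, 5, 3)
key = jax.random.PRNGKey(0)

# Built-in semantics: each of the 5 draws picks an entire a[n] slice of
# shape (5, 3), giving a result of shape (5, 5, 3).
direct = jax.random.choice(key, a, shape=(5,), axis=0)

# Since a[n, i, j] = 15 * n + 3 * i + j, integer division by 15 recovers n.
rows = direct // 15
print(direct.shape)
# True when all entries of each draw come from the same row n.
print(bool((rows == rows[:, :1, :1]).all()))
```

The output shape matches the `vmap` version, but the sampling is correlated across positions; the `vmap` formulation is what makes independent per-position sampling explicit.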

Answer selected by Qottmann
Category
Q&A