Hi Callum,

Thanks for the clear question!

The short answer is that you can still use `jax.vmap` here if you want to vmap the `apply_fn`. You only need `hk.vmap` if the vmap has to happen inside `hk.transform`. With `jax.vmap` you have complete control over whether the RNG state is shared or split across the vmapped dimension.
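To make the shared-versus-split distinction concrete, here is a minimal sketch in plain JAX. `noisy_apply` is a hypothetical stand-in for `network.apply` (it only applies a dropout-style random mask), and the shapes are made up for illustration. Passing a single key with `in_axes=None` broadcasts that same key to every example, so every row gets the same mask. Splitting the key and mapping over it gives each row its own key and therefore its own mask.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for network.apply with signature (params, rng, x).
# It ignores params and zeroes roughly half of the entries of x at random.
def noisy_apply(params, rng, x):
    keep = jax.random.bernoulli(rng, p=0.5, shape=x.shape)
    return jnp.where(keep, x, 0.0)

params = None  # the toy function has no parameters
key = jax.random.PRNGKey(0)
batch = jnp.ones((4, 64))  # 4 identical examples with 64 features each

# Shared RNG: the key is broadcast (in_axes=None), so every example sees the same key.
shared_out = jax.vmap(noisy_apply, in_axes=(None, None, 0))(params, key, batch)

# Split RNG: one key per example, mapped over the leading axis (in_axes=0).
keys = jax.random.split(key, batch.shape[0])
split_out = jax.vmap(noisy_apply, in_axes=(None, 0, 0))(params, keys, batch)

# With a shared key every row carries the same mask.
print(bool(jnp.all(shared_out == shared_out[0])))
# With split keys the masks are independent, so the rows almost surely differ.
print(bool(jnp.all(split_out == split_out[0])))
```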

If you rewrite your example from

```python
# The key is broadcast (in_axes=None), so every example shares the same RNG key.
batched_dummy_output = vmap(network.apply, in_axes=(None, None, 0))(params, key, batched_dummy_input)
```

to

```python
# Split the key into one key per example, then map over the keys (in_axes=0).
keys = jax.random.split(key, batched_dummy_input.shape[0])
batched_dummy_output = vmap(network.apply, in_axes=(None, 0, 0))(params, keys, batched_dummy_input)
```

this should do what you want it to (note that we are splitting the key …
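One way to check that the split-key version gives each example its own randomness is to compare it against an explicit Python loop that calls the function once per example with the matching key. The sketch below uses a hypothetical dropout-style `noisy_apply` in place of `network.apply` and wraps the split-and-vmap pattern in a hypothetical helper, `batched_apply`. Because the batch size comes from a static shape, the helper can also be wrapped in `jax.jit`.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for network.apply with signature (params, rng, x).
def noisy_apply(params, rng, x):
    keep = jax.random.bernoulli(rng, p=0.5, shape=x.shape)
    return params * jnp.where(keep, x, 0.0)

# Hypothetical helper: one RNG key per example, mapped over axis 0
# together with the inputs, while params are broadcast to every example.
@jax.jit
def batched_apply(params, key, batch):
    keys = jax.random.split(key, batch.shape[0])  # batch.shape[0] is static under jit
    return jax.vmap(noisy_apply, in_axes=(None, 0, 0))(params, keys, batch)

params = jnp.float32(2.0)
key = jax.random.PRNGKey(42)
batch = jnp.arange(3 * 5, dtype=jnp.float32).reshape(3, 5)

vmapped = batched_apply(params, key, batch)

# Reference: the same split keys, applied one example at a time.
keys = jax.random.split(key, batch.shape[0])
looped = jnp.stack([noisy_apply(params, keys[i], batch[i]) for i in range(batch.shape[0])])

print(bool(jnp.allclose(vmapped, looped)))
```

Mapping over the split keys behaves like a per-example loop, just vectorized. This is what makes the `in_axes=(None, 0, 0)` version differ from broadcasting a single key.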

Answer selected by callumtilbury