There may be some way to use a single vmap call, but you can definitely do:

import jax
result = jax.vmap(jax.vmap(jax.vmap(MyFunction)))(inputs)  # maps over the 3 leading axes of inputs

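By default (in_axes=0, out_axes=0), each vmap maps over the leading axis of its inputs and stacks the results along the leading axis of its outputs. From MyFunction's point of view, each level of nesting adds one more batch dimension to the left of the input and output shapes, so three nested vmaps handle inputs with three leading batch axes.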
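A minimal sketch of the nesting behaviour, assuming a hypothetical my_function that takes one unbatched vector of shape (5,) and returns a scalar, and a hypothetical inputs array with three batch axes (2, 3, 4) in front of it. The second half shows one possible way to get the same result from a single vmap call (flatten the batch axes, map once, reshape back); that approach is an illustration, not something confirmed in the thread.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for MyFunction: one unbatched vector (5,) -> scalar.
def my_function(x):
    return jnp.sum(x ** 2)

# Hypothetical input: batch axes (2, 3, 4) followed by the per-example axis (5,).
inputs = jnp.arange(2 * 3 * 4 * 5, dtype=jnp.float32).reshape(2, 3, 4, 5)

# Each vmap (default in_axes=0, out_axes=0) maps over one leading axis,
# so three nested vmaps consume the three batch axes.
nested = jax.vmap(jax.vmap(jax.vmap(my_function)))(inputs)

# Single-vmap alternative: collapse the batch axes into one, map once,
# then restore the original batch shape.
batch_shape = inputs.shape[:-1]
flat = inputs.reshape(-1, inputs.shape[-1])
single = jax.vmap(my_function)(flat).reshape(batch_shape)

# Reference result computed directly, without vmap.
reference = jnp.sum(inputs ** 2, axis=-1)

print(nested.shape)  # (2, 3, 4)
```

Both versions produce an output of shape (2, 3, 4): the batch dimensions stay on the left and my_function only ever sees a single (5,) vector. The reshape variant needs the batch axes to be contiguous and leading; nested vmaps (or non-default in_axes) are more flexible when they are not.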

Answer selected by dominicpasquali
Category
Q&A