I think you have the wrong mental model of what vmap is doing. vmap is about logical batching and does not imply anything about sequential computation of the batches. In the simplest cases, using vmap is identical to using standard NumPy-style axis arguments in functions. Here is a quick example showing this:

```python
from jax import vmap, make_jaxpr
import jax.numpy as jnp

x = jnp.ones((3, 4))

make_jaxpr(vmap(jnp.sum))(x)
# { lambda  ; a.
#   let b = reduce_sum[ axes=(1,) ] a
#   in (b,) }

make_jaxpr(lambda x: jnp.sum(x, axis=-1))(x)
# { lambda  ; a.
#   let b = reduce_sum[ axes=(1,) ] a
#   in (b,) }
```

Calling vmap on sum for 2D input is identical to calling an unmapped sum with an axis argument: t…
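The same holds for functions less trivial than a reduction. The sketch below is an illustration, not part of the original answer: `row_dot`, `W`, and `v` are hypothetical names. It vmaps a vector dot product over the rows of a matrix, checks the result against an explicit Python loop, and inspects the traced program. Assuming current JAX batching rules, the jaxpr should contain a single `dot_general` and no `while` or `scan` loop, although the exact printed jaxpr text varies between JAX versions.

```python
from jax import vmap, make_jaxpr
import jax.numpy as jnp

# Hypothetical per-example function written for a single vector input.
def row_dot(w, v):
    return jnp.dot(w, v)

W = jnp.arange(12.0).reshape(3, 4)  # 3 row vectors of length 4
v = jnp.ones(4)                     # one shared (unbatched) vector

# Map over the rows of W only; in_axes=None means v is not batched.
batched = vmap(row_dot, in_axes=(0, None))
out = batched(W, v)  # values 6, 22, 38: each row of W summed, since v is all ones

# Reference result: an explicit Python loop, one call per row.
looped = jnp.stack([row_dot(W[i], v) for i in range(W.shape[0])])
print(jnp.allclose(out, looped))

# The traced program is one batched primitive rather than a loop over rows.
jaxpr_text = str(make_jaxpr(batched)(W, v))
print(jaxpr_text)
```

Here the loop exists only in the reference computation; the vmapped version is lowered to one matrix-vector contraction, which is the same program you would get by writing `jnp.dot(W, v)` by hand.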

Answer selected by adam-hartshorne
Category
Q&A