vmap memory usage - is this expected behaviour? #6194
-
I have tried to batch my training data using a vmap-based approach, rather than manually selecting the batch size. However, if I do this, in terms of memory usage it seems to act as if the data isn't batched at all, and in my example it causes an out-of-memory exit.
My presumption was that vmap would vectorize in such a manner that only the available memory is filled. Is this an incorrect assumption, or is there something going wrong with vmap?
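Roughly, the pattern I mean is the following (a minimal sketch; `params`, `dataset`, and `loss_fn` are placeholders, not my actual code):

```python
import jax
import jax.numpy as jnp

# Placeholder model state and training data.
params = jnp.ones(4)
dataset = jnp.ones((100_000, 4))  # stands in for a much larger training set

def loss_fn(params, example):
    # Placeholder per-example loss.
    return jnp.sum((params - example) ** 2)

# vmap over the full dataset: the batching is logical, so the vectorized
# computation materializes per-example intermediates for all examples at once.
losses = jax.vmap(loss_fn, in_axes=(None, 0))(params, dataset)
```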
Replies: 1 comment
-
I think you have the wrong mental model of what `vmap` is doing. `vmap` is about logical batching, and does not imply anything about sequential computation of the batches. In the simplest cases, using `vmap` is identical to using standard numpy-style arguments in functions. Here is a quick example showing this:

```python
from jax import vmap, make_jaxpr
import jax.numpy as jnp

x = jnp.ones((3, 4))

make_jaxpr(vmap(jnp.sum))(x)
# { lambda ; a.
#   let b = reduce_sum[ axes=(1,) ] a
#   in (b,) }

make_jaxpr(lambda x: jnp.sum(x, axis=-1))(x)
# { lambda ; a.
#   let b = reduce_sum[ axes=(1,) ] a
#   in (b,) }
```

Calling `vmap` on `sum` for 2D input is identical to calling an unmapped `sum` with an axis argument: t… The closest thing JAX has to what you're after is …
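As a sketch of that sequential alternative, assuming the truncated recommendation above refers to something like `jax.lax.map` (which evaluates the mapped function once per element of the leading axis instead of vectorizing it):

```python
import jax
import jax.numpy as jnp

x = jnp.ones((3, 4))

# vmap: logical batching; the whole mapped axis is computed at once.
vectorized = jax.vmap(jnp.sum)(x)

# lax.map: sequential batching; jnp.sum is evaluated once per row,
# so peak memory is bounded by a single element of the mapped axis.
sequential = jax.lax.map(jnp.sum, x)

assert jnp.allclose(vectorized, sequential)
```

The trade-off is runtime: the sequential version avoids materializing the whole batch at once, at the cost of scanning over the leading axis one element at a time.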