Understanding vmap #21428
-
Let's say a vmapped function can be defined as …
-
No, JAX will not in general make copies of unmapped arrays, although it's possible for particular primitives to do things this way in their batching rule. If you want to be sure, you can use `jax.make_jaxpr` and inspect the size of the arrays created by the transformed function. For example:

```python
import jax
import jax.numpy as jnp

mapped_array = jnp.zeros((100000, 3))
array1 = jnp.arange(3)
array2 = jnp.ones(3)

def func(arr, arr1, arr2):
    return arr + arr1 + arr2

# arr is mapped over its leading axis; arr1 and arr2 are unmapped (in_axes=None)
print(jax.make_jaxpr(jax.vmap(func, in_axes=(0, None, None)))(mapped_array, array1, array2))
```

You can see that the unmapped input arrays (called …
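If you would rather make the `jax.make_jaxpr` check mechanical than read the printed jaxpr by eye, the following minimal sketch walks the top-level equations and reports each output shape. The helper name `intermediate_shapes` is hypothetical. It relies on the `ClosedJaxpr.jaxpr.eqns` structure, which is an internal representation that may change between JAX versions, and it only inspects top-level equations (it does not recurse into nested jaxprs such as those inside `pjit` or `scan`).

```python
import jax
import jax.numpy as jnp


def intermediate_shapes(fn, *args):
    """Hypothetical helper: trace `fn` with jax.make_jaxpr and return a list of
    (primitive name, output shape) for every top-level equation."""
    closed = jax.make_jaxpr(fn)(*args)
    rows = []
    for eqn in closed.jaxpr.eqns:
        for var in eqn.outvars:
            rows.append((eqn.primitive.name, tuple(var.aval.shape)))
    return rows


def func(arr, arr1, arr2):
    return arr + arr1 + arr2


mapped_array = jnp.zeros((100000, 3))
array1 = jnp.arange(3)
array2 = jnp.ones(3)

# Map over the leading axis of the first argument only.
batched = jax.vmap(func, in_axes=(0, None, None))
rows = intermediate_shapes(batched, mapped_array, array1, array2)
for name, shape in rows:
    print(f"{name:20s} {shape}")

# Heuristic: flag any top-level broadcast whose output carries the full batch size,
# which is what materializing a copy of an unmapped input would look like.
batch_size = mapped_array.shape[0]
full_size_broadcasts = [
    (name, shape) for name, shape in rows
    if name == "broadcast_in_dim" and batch_size in shape
]
print("full-size broadcasts:", full_size_broadcasts)
```

An empty `full_size_broadcasts` list is consistent with the answer above: the unmapped operands only get a size-1 leading axis and rely on the primitive's broadcasting. This is a shape-level heuristic on the traced program, not a memory profile of the compiled code.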