Efficiency of jit and vmap #21505
Replies: 1 comment 4 replies
-
Same outcome. In either case, the batched function is compiled. For example:
>>> import jax; import jax.numpy as jnp
>>> def f(m, v): return m @ v
...
>>> m, vs = jnp.ones((3, 4)), jnp.ones((7, 4))
>>> jax.make_jaxpr(jax.jit(jax.vmap(f, in_axes=(None, 0))))(m, vs)
{ lambda ; a:f32[3,4] b:f32[7,4]. let
    c:f32[7,3] = pjit[
      name=f
      jaxpr={ lambda ; d:f32[3,4] e:f32[7,4]. let
          f:f32[3,7] = dot_general[
            dimension_numbers=(([1], [1]), ([], []))
            preferred_element_type=float32
          ] d e
          g:f32[7,3] = transpose[permutation=(1, 0)] f
        in (g,) }
    ] a b
  in (c,) }
>>> jax.make_jaxpr(jax.vmap(jax.jit(f), in_axes=(None, 0)))(m, vs)
{ lambda ; a:f32[3,4] b:f32[7,4]. let
    c:f32[3,7] = pjit[
      name=f
      jaxpr={ lambda ; d:f32[3,4] e:f32[7,4]. let
          f:f32[3,7] = dot_general[
            dimension_numbers=(([1], [1]), ([], []))
            preferred_element_type=float32
          ] d e
        in (f,) }
    ] a b
    g:f32[7,3] = transpose[permutation=(1, 0)] c
  in (g,) }
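Both jaxprs contain the same batched dot_general inside the pjit; the only difference is whether the final transpose happens inside or outside the compiled call. As a rough sketch, the jit-wrapped version can also be lowered explicitly to inspect what reaches the compiler (continuing the same session, with the same f, m and vs):
>>> lowered = jax.jit(jax.vmap(f, in_axes=(None, 0))).lower(m, vs)  # trace + lower, but don't compile yet
>>> print(lowered.as_text())  # the StableHLO module for the whole batched computation
>>> compiled = lowered.compile()  # hand that module to XLA; the executable is cached per input shape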
The batching transformation would happen every time, i.e. not only for different batch sizes but also for previously seen batch sizes. There are no caching guarantees here, so this could change. But the batching transformation is also efficient and inline, and there isn't much overhead to it over standard evaluation: the system is essentially bookkeeping batch dimensions and calling into simple batching rules for evaluation. If you'll be compiling anyway, then the cost of re-running the batching transformation is unlikely to matter in practice.
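With jit on the outside, tracing (which is where the vmap machinery runs) and compilation happen only for input shapes that haven't been seen before. A minimal sketch to observe this, continuing the session above and using a Python-side print that fires only at trace time:
>>> def batched(m, vs):
...   print("tracing with vs.shape =", vs.shape)  # runs only while tracing, not on cached calls
...   return jax.vmap(f, in_axes=(None, 0))(m, vs)
...
>>> batched_jit = jax.jit(batched)
>>> _ = batched_jit(m, jnp.ones((32, 4)))  # traced and compiled
tracing with vs.shape = (32, 4)
>>> _ = batched_jit(m, jnp.ones((32, 4)))  # cache hit: no print, no recompilation
>>> _ = batched_jit(m, jnp.ones((64, 4)))  # new batch size: traced and compiled again
tracing with vs.shape = (64, 4)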
-
As demonstrated by the docs, jax.jit and jax.vmap can be composed to produce a batched, compiled version of the target function. What I want to know is whether there is a difference between jax.jit(jax.vmap(func)) and jax.vmap(jax.jit(func)) behind the scenes. It seems to me that the former should generate more efficient code, since the batched func is compiled as a whole.
A second question is about vmap alone. Does vmap produce a new batched version of the function each time it encounters a different batch size? For example, given jax.vmap(jax.jit(func))(x), will there be extra overhead if the input x is shaped (32, D) the first time and (64, D) the second? Will any re-mapping or re-compilation take place at the second call?
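One rough way to compare the two orderings empirically is to time them after a warm-up call, so the one-off tracing and compilation cost is excluded; a minimal sketch (the function, shapes, and iteration counts below are arbitrary):

import timeit
import jax
import jax.numpy as jnp

def func(v):                            # per-example function, applied to one row of x
    return jnp.sin(v).sum()

x = jnp.ones((32, 128))

jit_outside = jax.jit(jax.vmap(func))   # compile the batched function as a whole
jit_inside = jax.vmap(jax.jit(func))    # batch a compiled per-example function

# Warm up so tracing/compilation is not included in the timings.
jit_outside(x).block_until_ready()
jit_inside(x).block_until_ready()

# Any remaining difference is per-call Python dispatch/batching overhead;
# the batched computation itself is compiled in either case (see the reply above).
print(timeit.timeit(lambda: jit_outside(x).block_until_ready(), number=1000))
print(timeit.timeit(lambda: jit_inside(x).block_until_ready(), number=1000))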