How does JAX caching/recompilation work? #8458
I have a function:

```python
import jax
import numpy as np
from jax import vmap

def a(b, c):
    ...  # do something with b and c
```

and then I have:

```python
parallel_a = jax.pmap(vmap(a, in_axes=(0, None)), in_axes=(0, None))
```

I call:

```python
for i in range(10):
    xs = np.random.random((100, 100))
    b = np.random.random((100, 10))
    z = parallel_a(xs, b)
    print(z)
```

Following this code, the first call to …

```python
for i in range(1, 11):
    xs = np.random.random((100, 10 * i))
    b = np.random.random((10 * i, 10))
    z = parallel_a(xs, b)
    print(z)
```

Here, all 10 calls take 6 seconds to run. I'm using … My questions are: …

Thanks!
As a more general question: is there any way to check whether a given call will reuse an already-compiled version, or will compile first and then run?
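One rough way to observe this, sketched under assumptions: Python side effects inside a `jax.jit`-decorated function run only while JAX traces it, which happens on a cache miss just before compilation, and are skipped on cached calls. The counter `trace_count` and the function `scale` are hypothetical, and the shapes mirror the two loops in the question (with `jax.jit` in place of `pmap`). Note that this tells you after a call whether it retraced; it does not predict it beforehand. JAX also has a `jax_log_compiles` config option (`jax.config.update("jax_log_compiles", True)`) that logs a message each time a function is compiled.

```python
import jax
import jax.numpy as jnp

# Hypothetical counter: the Python body of a jitted function only executes
# while JAX traces it (on a cache miss), so this counts traces, not calls.
trace_count = 0

@jax.jit
def scale(x):
    global trace_count
    trace_count += 1
    return 2 * x

for _ in range(10):
    scale(jnp.ones((100, 100)))           # same shape and dtype every call
print(trace_count)  # 1

for i in range(1, 11):
    scale(jnp.ones((100, 10 * i)))        # (100, 100) is cached; 9 new shapes
print(trace_count)  # 10

scale(jnp.ones((100, 100), dtype=jnp.int32))  # same shape, new dtype
print(trace_count)  # 11
```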
Thanks for the question! The full answer is a bit more involved, but in general, recompilation is triggered by changes in the static properties of the input arrays, namely their shape and dtype. If a JIT-compiled function sees an input with the same shape and dtype as one it has seen before, the cached code is used. If it sees an input with a shape and/or dtype it has not seen before, this triggers a recompilation. With this in mind, the recompilations you're seeing are exactly as expected: every iteration of your second loop passes arrays of a new shape. Why is this? Mainly, it's due to the way XLA compilation works: programs are compiled for inputs of a particular shape. There has been some work on supporting dynamic shapes in XLA and in JAX, but it hasn't reached maturity yet.
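The point that XLA programs are compiled for inputs of a particular shape can be seen directly with JAX's ahead-of-time API. A minimal sketch, assuming a hypothetical body `row @ c` for the question's elided function and using `jax.jit` in place of `jax.pmap` so it runs on a single device: the executable returned by `.lower(...).compile()` only accepts the shapes and dtypes it was compiled for.

```python
import jax
import jax.numpy as jnp

def a(row, c):
    # Hypothetical body; the question does not say what `a` computes.
    return row @ c

# jit stands in for the question's pmap so the sketch runs on one device.
batched_a = jax.jit(jax.vmap(a, in_axes=(0, None)))

x = jnp.ones((100, 100), dtype=jnp.float32)
c = jnp.ones((100, 10), dtype=jnp.float32)

# Trace, lower, and compile for exactly these shapes and dtypes.
compiled = batched_a.lower(x, c).compile()

def accepts(fn, *args):
    """Return True if the compiled executable runs on these arguments."""
    try:
        fn(*args)
        return True
    except TypeError:
        # Raised when argument shapes/dtypes differ from the compiled ones.
        return False

print(compiled(x, c).shape)                                         # (100, 10)
print(accepts(compiled, x, c))                                      # True
print(accepts(compiled, jnp.ones((100, 20)), jnp.ones((20, 10))))   # False
print(accepts(compiled, x.astype(jnp.int32), c))                    # False
```

Ordinary `batched_a(...)` calls never hit this error: given a shape or dtype it has not seen, `jax.jit` traces and compiles a new executable and adds it to its cache, which is the recompilation observed in the question.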