When a jit function calls another jit function multiple times, does the latter need to be compiled each time? #18958
Suppose we've defined a jit function A:

```python
@jax.jit
def A():
    ...
```

As we know, when a non-jit function B calls A multiple times:

```python
def B():
    A()
    A()
```

starting from the second invocation, A() benefits from jit acceleration. But if a jit function B calls A multiple times:

```python
@jax.jit
def B():
    A()
    A()
```

then the first time I call B(), I find that each A() is compiled, which takes a long time. Thank you so much for any help!
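One way to measure how often A is actually traced when a jitted B calls it twice is to count executions of its Python body. Under `jax.jit` the Python body runs only while JAX traces the function, not on every call of the compiled version, so a counter incremented there counts traces. This is a minimal sketch assuming a recent JAX release; the `trace_counts` dict and the argument `x` are hypothetical additions not in the original question.

```python
import jax
import jax.numpy as jnp

# Hypothetical counter: incremented only while JAX traces each Python body,
# not each time the compiled computation runs.
trace_counts = {"A": 0, "B": 0}

@jax.jit
def A(x):
    trace_counts["A"] += 1
    return x * 2

@jax.jit
def B(x):
    trace_counts["B"] += 1
    # Two calls to A with identically shaped float32[3] arguments
    # inside a single trace of B.
    y = A(x)
    return A(y)

x = jnp.ones(3)
B(x)  # first call: B is traced (and A is traced inside it)
B(x)  # same shape and dtype: B's cached compilation is reused, no retrace
print(trace_counts)
```

If `trace_counts["A"]` stays at 1, the second call to A inside B reused the first trace rather than repeating the work; a larger value would indicate the repeated tracing described in the question.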
Thanks for the question! When I try the example you give:

```python
import jax

@jax.jit
def A():
    print('compiling a')
    return

@jax.jit
def B():
    A()
    A()

B()
# compiling a
print(A._cache_size())
# 1
```

It is possible that your simple example leaves out relevant details: for example, if the arguments to …

Can you try to put together a minimal reproducible example of the behavior you're seeing, including how you're determining that …
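As a general illustration (not a reconstruction of the detail cut off in the reply), a common reason a jitted function is traced and compiled more than once is that it is called with arguments of a different shape or dtype, or with new values for arguments marked in `static_argnums`, because `jax.jit` caches compiled code per abstract input signature. A minimal sketch, using a hypothetical function `f` and counter `n_traces`:

```python
import jax
import jax.numpy as jnp

n_traces = 0  # hypothetical counter, bumped only while f is being traced

@jax.jit
def f(x):
    global n_traces
    n_traces += 1
    return x + 1

f(jnp.ones(3))  # first call: traced and compiled for float32[3]
f(jnp.ones(3))  # same shape and dtype: served from the cache, no retrace
f(jnp.ones(4))  # new shape float32[4]: traced and compiled again
print(n_traces)  # → 2
```

To see compilations directly rather than counting traces, JAX also provides the `jax_log_compiles` configuration option, enabled with `jax.config.update("jax_log_compiles", True)`, which logs each compilation as it happens.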
Yes, updating JAX to v0.4.7 solved this problem. Thank you once again for your help!
Yes, it's possible that the compilation cache worked differently when JAX v0.3.14 was released (June 2022), so it's not surprising that you're seeing different results than you would with a more recent JAX version.
I'd suggest updating JAX to a more recent release; we're not going to be able to help you debug the details of the compilation cache behavior from such an old JAX version, as I suspect you're hitting an issue that has already been debugged and fixed in a more recent release.
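Before debugging caching behavior, it can help to confirm which JAX version is actually installed. A minimal sketch: the helper `version_tuple` is hypothetical, and the 0.4.7 comparison only reflects the version the asker reported as fixing the problem, not an official threshold.

```python
import jax

def version_tuple(v: str) -> tuple:
    """Parse the leading numeric components of a version string like '0.4.7'.

    Hypothetical helper: suffixes such as '.dev2024...' beyond the third
    component are ignored, and non-digit tails within a component are dropped.
    """
    parts = []
    for p in v.split(".")[:3]:
        digits = ""
        for ch in p:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits or 0))
    return tuple(parts)

print("installed jax:", jax.__version__)
if version_tuple(jax.__version__) < version_tuple("0.4.7"):
    print("This JAX is older than 0.4.7; consider upgrading to a recent release.")
```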