JIT Caching and Function Composition #10713
Hello, I have a question regarding the behavior of jit caching and function composition. Given an example as follows:

    import jax
    import jax.numpy as jnp

    def f(a, b):
        # f1 and f2 are two functions of fixed dimension output
        # regardless of input shape (e.g., jnp.sum)
        out1 = f1(a)
        out2 = f2(b)
        return out1 + out2

    # two different shapes
    arr1 = jnp.arange(5)
    arr2 = jnp.arange(10)

    fjit = jax.jit(f)
    _ = fjit(arr1, arr2)  # ex 1
    _ = fjit(arr1, arr1)  # ex 2
    _ = fjit(arr2, arr2)  # ex 3

In this example, f1 (ex 1 and 2) and f2 (ex 1 and 3) are passed inputs of the same shape across function calls, even though f sees different shapes. Does jax make any guarantees about (or even offer ways to inspect) whether a previously compiled function is being reused? With inlining and fusion, there's almost certainly no guarantee that f1 and f2 would always be reused. However, would they ever be reused, or does the cache only store elements at the level of ...
Replies: 1 comment
The cache only stores elements at the level of `f`. If you `jit` `f1` and `f2`, things might be different in the future, but not now (jax-0.3.12).
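
One informal way to inspect reuse is to count how often each Python function body runs, since a body only executes while JAX is tracing it. The sketch below is an illustration, not an official inspection API: the `trace_counts` dictionary and the bodies of `f1` and `f2` are assumptions made up for this example (the question only says they reduce to a fixed-size output, e.g. `jnp.sum`). A counter that stops growing on a repeated call suggests the compiled `f` was reused for that shape combination. Counters for `f1` and `f2` that grow whenever `f` is retraced suggest they are not cached separately when they are plain, un-jitted functions.

```python
import jax
import jax.numpy as jnp

# Hypothetical bookkeeping for this sketch: Python side effects inside a
# jitted function run only during tracing, not on cached executions.
trace_counts = {"f": 0, "f1": 0, "f2": 0}

def f1(x):
    trace_counts["f1"] += 1
    return jnp.sum(x)        # assumed stand-in for the question's f1

def f2(x):
    trace_counts["f2"] += 1
    return jnp.sum(x * x)    # assumed stand-in for the question's f2

def f(a, b):
    trace_counts["f"] += 1
    return f1(a) + f2(b)

fjit = jax.jit(f)
arr1 = jnp.arange(5)
arr2 = jnp.arange(10)

_ = fjit(arr1, arr2)  # ex 1
_ = fjit(arr1, arr1)  # ex 2
_ = fjit(arr2, arr2)  # ex 3
counts_after_new_shapes = dict(trace_counts)

_ = fjit(arr1, arr2)  # same shapes as ex 1 again
counts_after_repeat = dict(trace_counts)

print(counts_after_new_shapes)
print(counts_after_repeat)
```

Comparing the two printed dictionaries shows whether the repeated call triggered any new tracing. This only observes tracing of Python code; it says nothing about what XLA does with `f1` and `f2` after inlining and fusion.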