Given a jitted function:

```python
import jax

@jax.jit
def f(x):
    return x**2

df = jax.grad(f)

df(1.)  # Hopefully this populates some cache.
print(f._cache_size())
# ==> 0  # Was expecting 1

f(1.)  # Of course this populates the cache.
print(f._cache_size())
# ==> 1  # As expected, directly calling f populates the cache
```
Thanks for the question! It's a bit complicated, both for fundamental reasons (caching transformed versions of functions, given the way JAX's tracing works) and for transient, historical-path-dependent reasons.

There are two global jit caches: one for C++ dispatch and one for dispatch handled by Python. The Python cache is fully general, in the sense that it works with all manner of transformations applied to the jitted function, while the C++ cache only works with buffers-in, buffers-out dispatching, and hence not with transformation-of-jit cases. (The C++ cache is populated by stealing entries from the Python cache when it can.) The `_cache_size()` API (which isn't public AFAIK, is it eve…)

To see the cache grow in grad-of-jit calls, you need to dig pretty carefully:

```python
import jax

@jax.jit
def f(x):
    return x**2

df = jax.grad(f)

def hilariously_tricky_cache_size():
    # Reach into JAX internals: the Python-level dispatch cache lives in a
    # closure cell of the memoized _xla_callable.
    from jax._src import dispatch
    return sum(len(entries)
               for entries in dispatch._xla_callable.__closure__[1].cell_contents.values())

print(hilariously_tricky_cache_size())  # 0
df(1.)  # Hopefully this populates some cache.
print(hilariously_tricky_cache_size())  # 2
```

The reason there are two entries is that there's a cached jitted forward pass and a cached jitted backward pass. If you'd like a public API for that, open a feature request and let's discuss it! WDYT?