Thanks for the question!

It's a bit complicated, both for fundamental reasons (caching transformed versions of functions, given the way JAX's tracing works) and for transient, historical, path-dependent reasons.

There are two global jit caches: one for C++ dispatch and one for dispatch handled by Python. The Python cache is fully general: it works with any combination of transformations applied to the jitted function. The C++ cache only handles buffers-in, buffers-out dispatch, and hence not cases where a transformation is applied to a jitted function. (The C++ cache is populated by stealing entries from the Python cache when it can.)
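
One way to observe jit caching from the user side, without relying on the non-public `_cache_size()`, is to count traces with a Python side effect: the body of a jitted function only runs while JAX traces it, not on cache hits. This is a minimal sketch; the function `f` and the `trace_count` counter are hypothetical names chosen for illustration, and the grad call is just one example of a transformation applied to a jitted function.

```python
import jax
import jax.numpy as jnp

# Hypothetical counter: bumped only when JAX runs the Python body of `f`,
# which happens while tracing, not on cache hits.
trace_count = 0


@jax.jit
def f(x):
    global trace_count
    trace_count += 1  # Python side effect: executes at trace time only
    return jnp.sin(x) * 2.0


f(jnp.ones(3))  # first call: trace and compile for float32[3]
after_first = trace_count

f(jnp.ones(3))  # same shape/dtype: cache hit, the body is not re-run
after_repeat = trace_count

f(jnp.ones(4))  # new shape: a new cache entry, so f is traced again
after_new_shape = trace_count

# A transformation applied to the jitted function (grad-of-jit). Per the
# answer, the C++ cache does not handle this case; here it also traces f
# for a new scalar input.
grad_val = jax.grad(f)(1.0)

print(after_first, after_repeat, after_new_shape)  # 1 1 2
```

The counter only tells you whether a trace happened; it cannot tell you which of the two caches served a given hit.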

The _cache_size() API (which isn't public AFAIK, is it eve…

Replies: 1 comment 12 replies

@VolodyaCO
@yashk2810
@VolodyaCO
@patrick-kidger
@VolodyaCO

Answer selected by langmore
Category
Q&A
5 participants