
JAX's JIT compiles functions just in time, specializing each compilation to the shapes and dtypes of the inputs the function is called with. Compiling a function for one set of inputs does not prevent it from being compiled again later for a second set.
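As a small illustration, the sketch below uses JAX's ahead-of-time lowering API (`jax.jit(...).lower(...).compile()`) to compile the same function explicitly for two input signatures. The function `f` and the shapes are hypothetical examples; this is not from the original answer. Each compiled executable only accepts inputs matching the shape and dtype it was compiled for.

```python
import jax
import jax.numpy as jnp

# Hypothetical example function, used only for illustration.
def f(x):
    return jnp.sin(x) * 2.0

jitted = jax.jit(f)

# Lower and compile ahead of time for two different input signatures.
compiled_3 = jitted.lower(jax.ShapeDtypeStruct((3,), jnp.float32)).compile()
compiled_4 = jitted.lower(jax.ShapeDtypeStruct((4,), jnp.float32)).compile()

# Each executable is specialized to its own shape/dtype.
out3 = compiled_3(jnp.ones(3, dtype=jnp.float32))
out4 = compiled_4(jnp.ones(4, dtype=jnp.float32))
print(out3.shape, out4.shape)

# Passing a (4,)-shaped array to compiled_3 would raise an error,
# because that executable was compiled only for shape (3,).
```

Ordinary calls to `jitted(...)` do the same thing implicitly: JAX lowers and compiles on demand for each new input signature.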

The first time you call a jitted function with inputs of a given shape and dtype, the function is compiled and then executed. The next time you call it with inputs of matching shapes and dtypes, the cached compilation result is reused (i.e. no new compilation is necessary). If you call it with a different combination of shapes and dtypes, the function is compiled again, and the result of that new compilation is cached as well.
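One way to observe this caching behavior is to put a Python side effect inside the jitted function: Python code in the function body runs only while JAX traces it, which happens when a new compilation is needed, not on every call. The sketch below (the function `scaled_sum` and the counter are hypothetical, added for illustration) counts how many times tracing occurs.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, incremented only during tracing

@jax.jit
def scaled_sum(x):
    # This Python statement runs only when JAX traces the function
    # (i.e. on a cache miss), not on every call of the compiled function.
    global trace_count
    trace_count += 1
    return jnp.sum(x * 2.0)

a = scaled_sum(jnp.ones(3, dtype=jnp.float32))   # new shape/dtype: trace + compile
b = scaled_sum(jnp.zeros(3, dtype=jnp.float32))  # same shape/dtype: cached, no trace
c = scaled_sum(jnp.ones(5, dtype=jnp.float32))   # new shape: trace + compile again
d = scaled_sum(jnp.ones(3, dtype=jnp.int32))     # new dtype: trace + compile again

print(trace_count)  # 3
```

Only the first, third, and fourth calls trigger a new trace; the second call reuses the cached result from the first.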

This section of the docs might be good background on the mechanics of JIT: https://jax…

Answer selected by GoktugGuvercin
Category
Q&A
2 participants