

This is likely because on the first iteration JAX is tracing and compiling (via XLA) parts of the computation, while on subsequent iterations it reuses the cached compiled version instead of compiling again.
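A minimal sketch of this effect, assuming the slow first iteration comes from a `jax.jit`-compiled function (the original code isn't shown, so `step`, `timed_call`, and `trace_count` are hypothetical names). The Python-side counter only increments while JAX traces the function, so it shows when a cached executable is reused and when a new input shape forces a recompile:

```python
import time

import jax
import jax.numpy as jnp

# Hypothetical counter: Python side effects inside a jitted function run
# only while JAX traces it, not when a cached executable is reused.
trace_count = 0


@jax.jit
def step(x):
    global trace_count
    trace_count += 1
    return jnp.tanh(x @ x.T).sum()


def timed_call(x):
    """Run step(x) and return (result, elapsed seconds)."""
    start = time.perf_counter()
    # JAX dispatches work asynchronously; block so the timing covers execution.
    result = step(x).block_until_ready()
    return result, time.perf_counter() - start


x = jnp.ones((500, 500))
_, first = timed_call(x)   # traces + compiles, then runs
_, second = timed_call(x)  # same shape/dtype: reuses the cached executable
print(f"first call:  {first:.4f}s  (traces so far: {trace_count})")
print(f"second call: {second:.4f}s  (traces so far: {trace_count})")

# A different input shape is a cache miss, so JAX traces and compiles again.
y = jnp.ones((200, 200))
timed_call(y)
print(f"after new shape, traces: {trace_count}")
```

When benchmarking, a common approach is to call the function once as a warm-up and only time the calls after that, keeping input shapes and dtypes fixed so the cache keeps hitting.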


Answer selected by pvasired
Category
Q&A
2 participants