I think that's what static_argnums is for in the jax.jit API: it marks certain arguments as constant ("static") for a JIT-compiled function. My rule of thumb is to use it for arguments that cannot be coerced to JAX arrays, such as strings and callables. A static argument must be hashable, and each new value causes a cache miss, which triggers a full recompile of the function, so keep these arguments unchanged across calls as much as possible.

So, in your example, if A is at zero-based position n among the arguments of jax.scipy.sparse.linalg.gmres, calling jax.jit(gmres, static_argnums=(n,)) should fix this. A is the first parameter of gmres, so n should be 0.
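As a minimal sketch of the suggestion above, assuming A is passed as a Python callable (a matrix-vector product) in position 0: the names `dense`, `matvec`, and `jitted_gmres` and the 2x2 system are made up for illustration and are not from the original question.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres

# Illustrative 2x2 system; the values are made up for this sketch.
dense = jnp.array([[4.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, 2.0])


def matvec(x):
    # Linear operator supplied as a callable rather than as an array.
    return dense @ x


# A is positional argument 0 of gmres, so index 0 is marked static.
jitted_gmres = jax.jit(gmres, static_argnums=(0,))

# First call traces and compiles with `matvec` baked in as a constant.
x, info = jitted_gmres(matvec, b)

# Passing the very same function object again should reuse the cached
# compilation, because static arguments are compared by hash and equality.
x_again, _ = jitted_gmres(matvec, b)
```

Note that Python functions hash by identity, so building a fresh lambda on every call would likely count as a new static value each time and trigger a recompile. Defining the operator once and reusing it avoids that.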

Answer selected by peterdsharpe