The thing to keep in mind is that under JAX's JIT, array shapes must be static. That means the shape of an output array cannot depend on a traced value (for more on traced values, see How To Think In JAX). In your function, total_pad is a traced value and the shape of the output array depends on it, so the function cannot be JIT-compiled.
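Here is a minimal sketch of the failure mode. The original function isn't shown here, so `pad_right` is a hypothetical stand-in that appends `total_pad` zeros to a 1D array. Called eagerly it works, because `total_pad` is an ordinary Python int. Under `jax.jit`, `total_pad` becomes a tracer, so `jnp.zeros` cannot know the output length at trace time and JAX raises a `TypeError` (the exact subclass and message vary between JAX versions).

```python
import jax
import jax.numpy as jnp


def pad_right(x, total_pad):
    # Hypothetical example: append `total_pad` zeros to a 1D array.
    # The output length, x.shape[0] + total_pad, depends on the *value* of total_pad.
    return jnp.concatenate([x, jnp.zeros(total_pad, dtype=x.dtype)])


x = jnp.arange(3.0)

# Eager (no JIT): total_pad is a plain Python int, so the shape is known.
eager_shape = pad_right(x, 2).shape  # (5,)

# Under jax.jit, total_pad is traced, so it cannot be used as a shape.
try:
    jax.jit(pad_right)(x, 2)
    jit_failed = False
except TypeError:
    # JAX reports non-concrete shapes as a TypeError (or a subclass of it).
    jit_failed = True
```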

The only way around this is to ensure that the shapes of the output arrays do not depend on traced values. One way to do this is to mark the argument as static using @partial(jax.jit, static_argnums=1), as in the comment above your function, and also to avoid performing any JAX operations (like jnp.where) on the value you want to keep static; …
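A minimal sketch of that approach, assuming the goal is to pad a 1D array up to some target length (`pad_to_length` and `target_len` are hypothetical names, not taken from the original question). Because `target_len` is static and `x.shape[0]` is always static under JIT, the padding amount is computed with plain Python `max`, which yields a concrete int at trace time. Replacing that line with `jnp.where` or `jnp.maximum` would stage the computation into the traced program and bring the original error back.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=1)
def pad_to_length(x, target_len):
    # target_len is static (a concrete Python int at trace time), and
    # x.shape[0] is always static, so total_pad below is a plain Python int.
    # Do NOT use jnp.where / jnp.maximum here: under jit those produce tracers.
    total_pad = max(target_len - x.shape[0], 0)
    return jnp.concatenate([x, jnp.zeros(total_pad, dtype=x.dtype)])


x = jnp.arange(3.0)
y = pad_to_length(x, 5)  # compiled for target_len=5
z = pad_to_length(x, 2)  # a different static value triggers a recompile; no padding needed
```

Note the trade-off: each distinct value of a static argument causes a separate compilation, so this works best when the argument takes only a few distinct values.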

Replies: 1 comment

Answer selected by Joy-Lunkad
Category
Q&A
Labels
None yet
2 participants