
Thanks for the question!

When you run

dim_super = jnp.sum(dimX_submat[0, :])

within a JIT-compiled function, the result is a traced value even if dimX_submat itself is static (for example, a concrete NumPy array), because jax.numpy operations inside jax.jit are staged out rather than evaluated eagerly. Array shapes cannot depend on traced values, so when you then call jnp.zeros((dim_super, dim_super)) you get the tracer conversion error.
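A minimal reproduction of the failure, assuming dimX_submat is a small integer array of block sizes closed over by the jitted function. The array values and the function name build_supermatrix_bad are hypothetical, not taken from the original question. JAX's tracer and concretization errors subclass TypeError, so the sketch catches that.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical stand-in for dimX_submat: a concrete NumPy array of block sizes.
dimX_submat = np.array([[2, 3], [1, 1]])

@jax.jit
def build_supermatrix_bad():
    # jnp.sum is staged out under jit, so dim_super is a tracer,
    # even though dimX_submat is a concrete NumPy array.
    dim_super = jnp.sum(dimX_submat[0, :])
    # Shapes must be concrete integers at trace time, so this raises.
    return jnp.zeros((dim_super, dim_super))

try:
    build_supermatrix_bad()
    failed = False
except TypeError as err:  # JAX tracer/concretization errors subclass TypeError
    failed = True
    print(type(err).__name__, err)
```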

I haven't looked closely at your full algorithm, but to use it within JAX transforms like jax.jit, you will have to write it so that all array shapes are static. One useful pattern is to use standard NumPy (np) operations rather than jax.numpy when working with values that are known at trace time, such as sizes and shapes.
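A sketch of that pattern, under the same hypothetical dimX_submat as above. Computing the size with np.sum runs in ordinary Python at trace time, so dim_super is a concrete int and jnp.zeros gets a static shape; traced data can still fill the array. The second function shows one alternative when the block sizes are function inputs: pass them as a hashable tuple marked static via static_argnames (the names build_supermatrix, build_from_sizes, and block_sizes are illustrative only).

```python
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical static block sizes, known before tracing.
dimX_submat = np.array([[2, 3], [1, 1]])

@jax.jit
def build_supermatrix(values):
    # np.sum is evaluated eagerly at trace time, so dim_super is a Python int.
    dim_super = int(np.sum(dimX_submat[0, :]))
    out = jnp.zeros((dim_super, dim_super), dtype=values.dtype)
    # Fill the diagonal with traced data; the indices come from the static size.
    idx = np.arange(dim_super)
    return out.at[idx, idx].set(values)

@partial(jax.jit, static_argnames="block_sizes")
def build_from_sizes(values, block_sizes):
    # block_sizes is a hashable tuple of ints marked static, so sum() is concrete.
    # jit recompiles once for each distinct block_sizes tuple.
    dim_super = sum(block_sizes)
    out = jnp.zeros((dim_super, dim_super), dtype=values.dtype)
    idx = np.arange(dim_super)
    return out.at[idx, idx].set(values)

result = build_supermatrix(jnp.arange(5.0))
result2 = build_from_sizes(jnp.arange(5.0), block_sizes=(2, 3))
print(result.shape, result2.shape)
```

Note that a NumPy array cannot be passed directly as a static argument because it is not hashable, which is why the second version takes a tuple.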

For more information on this, the "How to think in JAX" document in the JAX documentation is a useful read.

Answer selected by jakevdp
Category
Q&A