The issue here is that you're attempting to create a dynamically-shaped array (see JAX sharp bits: dynamic shapes).

jnp.nonzero returns an array whose length is the number of nonzero elements in its argument. That length depends on the array's values and cannot be known at compile time, so using it this way is incompatible with transformations like jit and vmap, which require all arrays to be statically shaped.
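To make the problem concrete, here is a minimal sketch (not the code from the original question; the array `x` and the function name `nonzero_indices` are made up for illustration). Called eagerly, the output length follows the data; under jit the same call is rejected because the length is not known when the function is traced.

```python
import jax
import jax.numpy as jnp

# Hypothetical input with three nonzero entries.
x = jnp.array([0, 3, 0, 5, 7])

# Eager call: the output length depends on the values in x.
(idx,) = jnp.nonzero(x)
print(idx.shape)  # (3,)

@jax.jit
def nonzero_indices(x):
    # No static size is given, so the output shape is unknown at trace time.
    return jnp.nonzero(x)[0]

try:
    nonzero_indices(x)
    jit_failed = False
except Exception as err:
    # JAX raises an error here; we only record that tracing failed.
    jit_failed = True
    print(type(err).__name__)
```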

The solution would be to re-express your algorithm in a way that uses statically-shaped arrays. For example, if the number of nonzero entries is known statically, you can pass that to the size argument of jnp.nonzero. I'm not sure whether this is relevant in your case: I'm not f…
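As a rough illustration of the size argument, assuming you can bound the number of nonzero entries ahead of time (the bound of 3, the `fill_value=-1` padding marker, and the function name `nonzero_indices_padded` are choices made for this sketch, not something from the original question):

```python
import jax
import jax.numpy as jnp

@jax.jit
def nonzero_indices_padded(x):
    # size=3 is an assumed static bound on the number of nonzero entries.
    # If there are fewer, the result is padded with fill_value.
    (idx,) = jnp.nonzero(x, size=3, fill_value=-1)
    return idx

a = nonzero_indices_padded(jnp.array([0, 3, 0, 5, 7]))
b = nonzero_indices_padded(jnp.array([0, 0, 0, 5, 0]))
print(a)  # [1 3 4]
print(b)  # [ 3 -1 -1]
```

The output always has length 3, so it works under jit. Note that if the input has more nonzero entries than size, the extra indices are dropped, so downstream code needs to handle both padding and truncation.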

Answer selected by hsharsh
@jakevdp
Category
Q&A