
Yes: in general, JAX does not (yet) support data-dependent shapes inside a jit. Here, num necessarily depends on the values in the matrix, so the shape of the sliced output cannot be known when the function is traced. Outside a jit it works fine.
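To make the failure concrete, here is a minimal sketch. It assumes an orth-like function (the name `orth_exact` and the `rcond` tolerance are hypothetical, not the original code) that keeps the left singular vectors whose singular values exceed a tolerance. The slice `u[:, :num]` succeeds eagerly because `num` is a concrete value, but under `jax.jit` it is a tracer, so the output shape is unknown at trace time and JAX raises an error.

```python
import jax
import jax.numpy as jnp

def orth_exact(A, rcond=1e-6):
    # Hypothetical orth-like function: keep left singular vectors whose
    # singular values are above a relative tolerance.
    u, s, _ = jnp.linalg.svd(A, full_matrices=False)
    num = jnp.sum(s > s.max() * rcond)  # data-dependent integer
    return u[:, :num]                   # output shape depends on num

A = jnp.diag(jnp.array([3.0, 1.0, 0.0]))  # rank-2 matrix

# Eagerly, num is concrete, so slicing works and Q has 2 columns.
eager_shape = orth_exact(A).shape

# Under jit, num is a tracer and the slice bound is not static.
jit_failed = False
try:
    jax.jit(orth_exact)(A)
except Exception as err:  # JAX reports non-static slice indices
    jit_failed = True
    print(type(err).__name__, err)
```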

The usual workaround in JAX is to return a padded array: return Q unsliced, together with num, which tells you how much of the array contains valid data. This requires an API change; sometimes we handle it by adding an optional keyword argument that selects between exactly mimicking the NumPy/SciPy behavior and returning the (Q, num) pair in a jit-compatible way.
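A minimal sketch of that pattern, under the same assumptions as an orth-like function (the keyword `exact_shape` and the helper `orth_padded` are hypothetical names chosen for illustration). With `exact_shape=True` the function mimics NumPy/SciPy and only works outside jit; with `exact_shape=False` it returns the full `Q` plus `num`, and the caller masks the unused columns so every shape stays static.

```python
from functools import partial

import jax
import jax.numpy as jnp

def orth(A, rcond=1e-6, *, exact_shape=True):
    # Hypothetical orth-like API with an opt-in jit-compatible mode.
    u, s, _ = jnp.linalg.svd(A, full_matrices=False)
    num = jnp.sum(s > s.max() * rcond)  # number of useful columns
    if exact_shape:
        # NumPy/SciPy-style result: data-dependent shape, eager only.
        return u[:, :num]
    # jit-compatible: fixed-shape Q plus the count of valid columns.
    return u, num

# The Python-level flag is fixed before tracing, so jit sees static shapes.
orth_padded = jax.jit(partial(orth, exact_shape=False))

A = jnp.diag(jnp.array([3.0, 1.0, 0.0]))  # rank-2 matrix
Q, num = orth_padded(A)

# Zero out the padding columns (index >= num) with a broadcast mask.
mask = jnp.arange(Q.shape[1]) < num
Q_masked = jnp.where(mask, Q, 0.0)
```

Masking with `jnp.where` keeps downstream computations (for example `Q_masked @ Q_masked.T`) jit-compatible, at the cost of carrying the padded columns around.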

We have plans for more first-class support for dynamic shapes inside jit, but nothing is ready at this time.

Answer selected by shailesh1729
Category
Q&A
2 participants