The issue is that JAX's JIT cannot create dynamically shaped arrays, and your function creates one: the value of 3*a + 2*b is not known at compile time, because a and b are not marked as static.

To fix this, you should do what you did in the second function, and mark as static any parameters that determine the shape of an array you create within a jit-compiled function. For example, this works:

print(jit(func, static_argnums=[0, 1])(2,3))
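For a fuller picture, here is a minimal sketch of both the failing and the working call. The original `func` is not shown in this answer, so the body below is a hypothetical stand-in that simply builds an array of length 3*a + 2*b; the names `traced_ok` and `fast_func` are also just for illustration.

```python
import jax
import jax.numpy as jnp


def func(a, b):
    # Hypothetical stand-in for the function in the question:
    # the output length depends on the *values* of a and b.
    return jnp.ones(3 * a + 2 * b)


# Without static arguments, a and b are traced, so 3*a + 2*b is an
# abstract tracer rather than a concrete int, and array creation fails.
try:
    jax.jit(func)(2, 3)
    traced_ok = True
except TypeError:
    traced_ok = False

# Marking both arguments as static makes their values known at trace time,
# so the shape is concrete.
fast_func = jax.jit(func, static_argnums=(0, 1))
out = fast_func(2, 3)
print(out.shape)
```

Keep in mind that static arguments must be hashable, and that each new combination of static values triggers a fresh compilation, so this works best when a and b take only a few distinct values.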

You'll find more background and explanation of this topic in the How To Think In JAX doc.

Answer selected by insuhan