Thanks for the question!

The error occurs because an argument that is vmapped over is traced rather than static, so you cannot pass it to a function that marks that argument as static.

The fix is to not mark the argument static if you want to vmap over it.

The problem here, though, is that you're passing this argument (s) on to another operation that requires a static value: slicing. In JAX, array shapes must be static, so the s in x[:s] cannot be a traced value.

The solution here, if you want to sum over dynamically-sized slices, is to write your code in a way that does not require generating dynamically-sized arrays. Here is a workaround:

import jax
import j…
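Since the snippet above is cut off, here is a minimal sketch of one common way to do this, not necessarily the exact code from the original answer. It assumes the goal is the sum of x[:s] for several values of s. The helper name masked_prefix_sum and the example inputs are hypothetical. Rather than slicing, it keeps the array at its full static size and masks out the entries at or beyond index s:

```python
import jax
import jax.numpy as jnp

def masked_prefix_sum(x, s):
    # Hypothetical helper: equivalent to x[:s].sum(), but s may be a traced value
    # because no dynamically sized array is ever created.
    idx = jnp.arange(x.shape[0])            # x.shape[0] is static, so this is fine
    return jnp.where(idx < s, x, 0).sum()   # zero out entries at index >= s

x = jnp.arange(10.0)
sizes = jnp.array([1, 3, 10])

# s is not marked static, so it can be vmapped over.
sums = jax.vmap(masked_prefix_sum, in_axes=(None, 0))(x, sizes)

# The same function also works under jit with a dynamic s.
jit_sum = jax.jit(masked_prefix_sum)(x, 4)
```

The trade-off is that every call does work proportional to the full length of x, regardless of s, but the shapes stay static so the function composes with both jit and vmap.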

Answer selected by guglielmogattiglio