
JAX doesn't support Python control flow that depends on traced values. Specifically, a `Tracer` doesn't support the `__index__` method, which is what `range(n)` and indexing a Python tuple or list with `arr[i]` (among others) rely on.
There are two options:
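The original question's code isn't shown here, so the snippet below is only a minimal reproduction under that assumption: a function wrapped in `jax.jit` whose integer argument becomes a tracer. The names `f`, `g`, `n`, `i` and `COEFFS` are hypothetical. Both `range(n)` and `COEFFS[i]` end up calling `__index__` on a tracer, which raises `jax.errors.TracerIntegerConversionError`.

```python
import jax

COEFFS = (1.0, 2.0, 3.0)  # hypothetical Python tuple

@jax.jit
def f(x, n):
    total = 0.0
    for i in range(n):  # range() calls n.__index__(), but n is a tracer under jit
        total = total + x * i
    return total

@jax.jit
def g(i):
    return COEFFS[i]  # tuple indexing also calls i.__index__() on the tracer

error_names = []
for call in (lambda: f(2.0, 3), lambda: g(1)):
    try:
        call()
        error_names.append(None)
    except jax.errors.TracerIntegerConversionError as err:
        error_names.append(type(err).__name__)

print(error_names)
```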

  1. Mark `a` as a static argument, and compute `N` with `numpy` instead of `jnp`, or use `jax.ensure_compile_time_eval` (not recommended).
  2. Use `jax.lax.while_loop` instead of `for i in range(N)`. See https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html
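A rough sketch of both options, assuming `a` is a Python scalar and `N` is derived from it. The formula `ceil(sqrt(a))`, the loop body, and the names `f_static`, `f_static_ctev` and `f_while` are all hypothetical stand-ins for the original question's code. Option 1 needs `a` to be hashable (e.g. a Python `int`) to be used as a static argument.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np

# Option 1: `a` is static, and N is computed with NumPy, so it is a concrete
# Python int at trace time and range(N) works (the loop is unrolled).
@partial(jax.jit, static_argnums=0)
def f_static(a, x):
    N = int(np.ceil(np.sqrt(a)))  # hypothetical formula for N
    total = jnp.zeros_like(x)
    for i in range(N):
        total = total + x * i
    return total

# Option 1 variant (not recommended): keep jnp, but force eager evaluation
# of the N computation while tracing.
@partial(jax.jit, static_argnums=0)
def f_static_ctev(a, x):
    with jax.ensure_compile_time_eval():
        N = int(jnp.ceil(jnp.sqrt(a)))
    total = jnp.zeros_like(x)
    for i in range(N):
        total = total + x * i
    return total

# Option 2: `a` stays traced, and the Python loop becomes lax.while_loop.
@jax.jit
def f_while(a, x):
    N = jnp.ceil(jnp.sqrt(a)).astype(jnp.int32)  # traced value, never converted to int

    def cond_fun(state):
        i, _ = state
        return i < N

    def body_fun(state):
        i, total = state
        return i + 1, total + x * i

    _, total = jax.lax.while_loop(cond_fun, body_fun, (0, jnp.zeros_like(x)))
    return total

x = jnp.float32(2.0)
print(f_static(10, x), f_static_ctev(10, x), f_while(10.0, x))
```

With option 1, `jit` retraces and recompiles for every distinct value of `a`, and the loop is unrolled into the compiled program. Option 2 compiles once, and `N` can vary at runtime. Note that `lax.while_loop` is not reverse-mode differentiable; if you need gradients and `N` has a known upper bound, `jax.lax.fori_loop` with static bounds or `jax.lax.scan` are the usual alternatives.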

@lanmao998

@tomsturges

@jakevdp

Answer selected by lanmao998
Category: Q&A
Labels: None yet
4 participants