You jit-compile wrapper_inner_loop_body4, which means that all arguments passed to it are traced (non-static), including step, the first element of carry.

You then pass step as the first argument to inner_loop_body4, where that argument is marked as static, and this leads to an error. To fix that, avoid marking this argument as static. The reason this causes issues is that you're using Python's `and` operator with traced variables: `and` attempts to eagerly convert its operands to static booleans (a behavior that is impossible to override). For that reason, JAX follows NumPy and uses the elementwise and operator (`&`) for this type of operation:

leader_dist = min_pos_delta_if_equal(pos[road], num_agents_p…

Answer selected by mjhoover1