In general, yes: functions used in higher-order primitives need to be pure in the sense of not having any side effects. However, JAX is less strict about purity in the sense of depending on external variables, even when those variables are captured in closures. The JIT-tracing machinery captures `loop_body`'s dependence on values like `a` in your example and handles it correctly. You should be able to confirm this by running your function with various values of `a`. (On that point, I think your loop function should be defined as `def loop_body(idx, state)` rather than `def loop_body(state, idx)`: assuming you are using `lax.fori_loop`, it calls its body as `body_fun(i, val)`, with the loop index first.)
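
Here is a minimal sketch of both points, assuming the loop in question is a `lax.fori_loop`. The names `run` and `trace_count` and the loop arithmetic are hypothetical and only for illustration; `a` and `loop_body` mirror the names from the question. `loop_body` reads `a` from its enclosing scope instead of taking it as an argument, and the jitted function still returns the right result for each value of `a`. The Python-level counter is a deliberate side effect: it runs only while JAX traces `run`, not on every call, which is why side effects are the part of purity you do need to care about.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical counter: a Python-level side effect. It only advances while
# JAX traces `run`, not on every call of the compiled function.
trace_count = 0


def run(a, x0):
    global trace_count
    trace_count += 1

    # `a` is not an argument of loop_body; it is captured from the enclosing
    # scope (a closure). Under jit, `a` is a tracer here, so the traced
    # computation records loop_body's dependence on it.
    def loop_body(idx, state):
        # lax.fori_loop calls body_fun(i, val): the loop index comes first.
        return state * a + idx

    x0 = jnp.asarray(x0, dtype=jnp.float32)
    return lax.fori_loop(0, 3, loop_body, x0)


run_jit = jax.jit(run)

for a in (1.0, 2.0, 3.0):
    # With x0 = 1 and three iterations, the result is a**3 + a + 2.
    print(a, run_jit(a, 1.0))

# Different values of `a` reuse the same trace, so the counter stays at 1.
print("traces:", trace_count)
```

Because `a` is passed to the jitted function as an argument, changing its value does not trigger a retrace; the closure simply refers to the traced value.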

Answer selected by hr0nix