Is this legitimate JAX code?
Here I'm wondering whether this is a legitimate usage of …
In general, yes: functions used in higher-order primitives need to be pure in the sense of not having any side effects. However, JAX is less strict about purity in the sense of depending on external variables, even when those variables are defined within closures. The JIT-tracing machinery is able to capture `loop_body`'s dependence on values like `a` in your example and operate correctly. You should see this if you run your function while passing various values for `a`.

(On that point, I think your loop function should be defined as `def loop_body(idx, state)` rather than `def loop_body(state, idx)`.)
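The original question's code isn't shown here, so the snippet below is only a hypothetical sketch. It assumes the loop is driven by `jax.lax.fori_loop`, which calls its body as `body_fun(i, val)` (index first, carry second). The names `run_loop` and `reference` are illustrative. `loop_body` reads `a` from the enclosing jitted function through a closure rather than carrying it in the loop state. Because `a` is traced rather than baked in as a constant, different values of `a` should give the correct results.

```python
import jax
import jax.numpy as jnp


@jax.jit
def run_loop(a):
    # Hypothetical stand-in for the question's function: `a` is an argument
    # of the jitted function, and `loop_body` reads it from the enclosing
    # scope (a closure) instead of receiving it through the loop state.
    a = jnp.asarray(a, dtype=jnp.float32)

    # jax.lax.fori_loop calls body_fun(i, val): loop index first, carry second.
    def loop_body(idx, state):
        return state + a * idx

    init = jnp.zeros((), dtype=jnp.float32)
    # Accumulates the sum of a * idx for idx = 0, 1, ..., 9.
    return jax.lax.fori_loop(0, 10, loop_body, init)


def reference(a):
    # Plain-Python version of the same computation, for comparison.
    return sum(a * idx for idx in range(10))


for a in (0.5, 1.0, 3.0):
    print(a, float(run_loop(a)), reference(a))
```

The jitted and plain-Python results should agree for each value of `a`. This is the kind of check suggested above for confirming that the closure-captured value is handled correctly.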