jax.experimental.ode.odeint #13345
Hello, I am quite new to JAX. I tried using odeint from the experimental package and noticed that it returns tracers instead of device arrays. Is this due to the JIT compilation, the custom_vjp, or both? I tried using disable_jit to get back regular JAX NumPy arrays, but that did not work. Is there any way to use the JAX odeint implementation without getting the results as tracers? Thanks.
Replies: 1 comment
Do you mean that tracers are being passed into your vector field? This is expected, and is due to the JIT compilation.

If you're trying to call a non-JAX operation, then try using `jax.pure_callback` inside your vector field.

Side note: you may wish to try Diffrax, which is much more featureful (and typically a bit faster) than `odeint`.
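For illustration, here is a minimal sketch of the `jax.pure_callback` suggestion. The NumPy function `numpy_decay_rate`, the decay constant, and the initial values are hypothetical stand-ins for whatever non-JAX operation you need to call; only `odeint` and `jax.pure_callback` come from the thread.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint


def numpy_decay_rate(y):
    # Hypothetical non-JAX operation: runs on the host with concrete values,
    # so plain NumPy works here even though the solver itself is traced.
    return -0.5 * np.asarray(y)


def vector_field(y, t):
    # Inside odeint, y and t are tracers. pure_callback needs to know the
    # shape and dtype of the result up front; here it matches y.
    out_spec = jax.ShapeDtypeStruct(y.shape, y.dtype)
    return jax.pure_callback(numpy_decay_rate, out_spec, y)


y0 = jnp.array([1.0, 2.0])
ts = jnp.linspace(0.0, 1.0, 5)

# Called outside of jit/grad, the solution comes back as a concrete array
# with one row per entry of ts.
ys = odeint(vector_field, y0, ts)
print(ys.shape)
```

The callback must be a pure function of its inputs. Note also that `jax.pure_callback` does not define derivatives on its own, so this sketch only covers the forward solve, not differentiating through `odeint`.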