jax.debug.breakpoint() for jax.experimental.ode.odeint? #14307
Unanswered
virajpandya asked this question in Q&A
Replies: 0 comments
I am trying to debug ODE integration code that I ported from pure Python (using scipy.integrate.solve_ivp) to JAX. scipy.integrate.solve_ivp works well with several different solver choices, but my JAX port, which is based on jax.experimental.ode.odeint, fails, and I'm trying to figure out why. So far I have added a number of jax.debug.print() statements to my integrator function (the function passed to solve_ivp / odeint), but this is tedious and I have not found the problem yet. Since the "ys" array returned by odeint contains many NaNs, I may have a bug in one of the other (jitted) functions that my main integrator function calls.
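One way to avoid reading through many print statements is to have the integrator function report only the evaluations that go bad. The sketch below is a minimal illustration under stated assumptions: `rhs` is a hypothetical toy vector field with a NaN deliberately built in for y <= 0, and `bad_evals` / `_record_if_bad` are illustrative names. It uses `jax.debug.callback`, which runs an ordinary Python function on the host with concrete values every time odeint evaluates the RHS, including at intermediate Runge-Kutta stages and rejected steps.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

# Host-side record of RHS evaluations that produced non-finite values.
bad_evals = []


def _record_if_bad(t, y, dy):
    # Runs on the host with concrete values, once per RHS evaluation.
    # Callbacks are unordered by default, so don't rely on list order.
    if not np.all(np.isfinite(dy)):
        bad_evals.append((float(t), float(y), float(dy)))


def rhs(y, t):
    # Hypothetical vector field: dy/dt = -1 while y > 0. The 0 * log(y)
    # term is 0 for y > 0 but NaN once y <= 0, mimicking a hidden bug.
    dy = -1.0 + 0.0 * jnp.log(y)
    jax.debug.callback(_record_if_bad, t, y, dy)
    return dy


y0 = jnp.array(1.0)
ts = jnp.linspace(0.0, 2.0, 5)
# mxstep caps the number of internal steps per output interval (a safeguard).
ys = odeint(rhs, y0, ts, mxstep=200)
ys.block_until_ready()
jax.effects_barrier()  # wait until all pending host callbacks have run

for t, y, dy in bad_evals[:3]:
    print(f"non-finite RHS at t={t:.4g}, y={y:.4g}, dy={dy}")
```

The recorded (t, y) pairs can then be fed back into the integrator function outside odeint to see which intermediate quantity turns non-finite.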
Is there a way to use jax.debug.breakpoint() to interactively inspect the values of intermediate variables computed within the integrator function passed to odeint? I tried putting jax.debug.breakpoint() in my integrator function, but I get an XLA RuntimeError (see below). My guess is that breakpoint() doesn't make sense for odeint, but I wanted to double-check.
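Since jax.debug.breakpoint() raises an error inside odeint here, a possible workaround (an assumption, not a confirmed fix for the odeint case) is to take a suspicious state and evaluate the integrator function and its helpers directly, outside odeint. Under `jax.disable_jit()`, the bodies of `@jax.jit` functions run as plain Python, so an ordinary `breakpoint()` or `print()` placed inside them sees concrete arrays rather than tracers. `reaction_rate`, `rhs`, and `inspect_rhs` below are hypothetical stand-ins for your own functions.

```python
import jax
import jax.numpy as jnp


@jax.jit
def reaction_rate(y):
    # Hypothetical jitted helper: sqrt of a negative concentration gives NaN.
    return jnp.sqrt(y[0]) * y[1]


@jax.jit
def rhs(y, t):
    # Hypothetical integrator function in odeint's func(y, t) convention.
    r = reaction_rate(y)
    return jnp.stack([-r, r])


def inspect_rhs(y, t):
    """Evaluate the RHS and its helpers on concrete inputs, outside odeint."""
    with jax.disable_jit():
        # With jit disabled, the bodies of reaction_rate and rhs run as plain
        # Python, so a built-in breakpoint() or print() inside them would see
        # concrete arrays instead of tracers.
        r = reaction_rate(y)
        dy = rhs(y, t)
    return {
        "reaction_rate": r,
        "dy": dy,
        "nonfinite": not bool(jnp.all(jnp.isfinite(dy))),
    }


# A suspicious state, e.g. one copied from a jax.debug.print line.
report = inspect_rhs(jnp.array([-1.0, 2.0]), 0.0)
print(report)
```

Once the offending helper is identified this way, any remaining jax.debug.print calls can be limited to that one function.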