Looking for a fast implementation of many trajectories using odeint #15404
alfercorral asked this question in Q&A
Also, I looked at the suggested Diffrax version of similar code, written in a jitted, vmapped manner:

```python
import jax
from diffrax import ODETerm

term = ODETerm(force)
def sol(t_0, y0):
    ...  # body not shown in the post
sol = jax.jit(jax.vmap(sol, in_axes=(0, 0)))
```

Even so, this solution does not outperform the plain odeint solution. Any suggestions, @patrick-kidger?
I am using jax.experimental.ode.odeint to solve a coupled ODE for many possible trajectories (with different initial conditions in both space and time) in a Monte Carlo fashion: I generate millions of random events from a probability distribution and solve many of them to see what the final distribution looks like. Each event has its own initial spatial configuration y0, of shape (6,), and its own initial time t0 (a scalar), but all trajectories end at the same final time t_f.
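A minimal sketch of the batched, shared-start-time case, assuming a toy right-hand side: `force` below is a hypothetical stand-in (a 3-D unit harmonic oscillator with state (x, v)), not the actual force from this post. Because it operates on a whole (N, 6) array, a single `odeint` call integrates every trajectory on one shared time grid.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint


def force(y, t):
    # Hypothetical stand-in for the real right-hand side: a 3-D unit harmonic
    # oscillator with state y = (x, v). Works on one (6,) state or an (N, 6) batch.
    x, v = y[..., :3], y[..., 3:]
    return jnp.concatenate([v, -x], axis=-1)


n_traj, n_steps = 1000, 50
t0, t_f = 0.0, 5.0
y0 = jax.random.normal(jax.random.PRNGKey(0), (n_traj, 6))  # random initial states
t = jnp.linspace(t0, t_f, n_steps)  # one time grid shared by every trajectory

# One odeint call integrates the whole (n_traj, 6) batch as a single ODE system.
ys = odeint(force, y0, t, rtol=1e-6, atol=1e-6)
print(ys.shape)  # (50, 1000, 6)
```

The output has the time axis first, followed by the shape of `y0`.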
I found that odeint is very fast when all events share the same time grid, that is, when they all start at the same t0. In that case no extra code is needed, because odeint can handle many initial spatial configurations at once. However, when each event has its own time grid, built from:
```python
t = [jnp.linspace(t_i, t_f, n_steps) for t_i in t_0]
t = jnp.array(t)
```
I cannot find an optimised version of the code that comes close to the speed of the shared-t0 case. Run time goes from seconds for thousands of trajectories that share one t0 to minutes for just hundreds of trajectories with different t0.
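The list comprehension that builds the per-event grids can also be written with broadcasting, which produces the same (N, n_steps) array without a Python loop. This is a sketch with hypothetical start times `t_0`; it only speeds up grid construction and does not by itself make the integration faster.

```python
import jax.numpy as jnp

t_f, n_steps = 10.0, 5
t_0 = jnp.array([0.0, 2.5, 5.0, 7.5])  # hypothetical per-event start times, shape (N,)

# Fractions 0, 1/(n_steps - 1), ..., 1, shared by every event.
frac = jnp.linspace(0.0, 1.0, n_steps)

# (N, 1) start times and spans broadcast against the (n_steps,) fractions,
# giving an (N, n_steps) array whose row i runs from t_0[i] to t_f.
t = t_0[:, None] + (t_f - t_0)[:, None] * frac
print(t.shape)  # (4, 5)
```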
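I have also tried a jitted, vmapped version of the code:

```python
import jax
from jax.experimental import ode

def solve_batch(y0, t):
    # Map over the leading axis of y0 (N, 6) and t (N, n_steps); force is closed over.
    return jax.vmap(lambda y, tt: ode.odeint(force, y, tt))(y0, t)
solve_batch = jax.jit(solve_batch)
```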
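A self-contained sketch of the vmapped approach, again with a hypothetical harmonic-oscillator `force` and made-up start times. One plausible reason this is not much faster than a loop: `odeint` takes its adaptive steps inside a `while_loop`, and under `vmap` a batched `while_loop` keeps iterating until every lane has finished, so each trajectory effectively pays for the slowest one. That is a general property of batched loops in JAX, not something measured on the code in this post.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint


def force(y, t):
    # Hypothetical stand-in right-hand side: 3-D unit harmonic oscillator, y = (x, v).
    x, v = y[..., :3], y[..., 3:]
    return jnp.concatenate([v, -x], axis=-1)


@jax.jit
def solve_each(y0, t):
    # y0: (N, 6), t: (N, n_steps). Each vmap lane integrates one trajectory on
    # its own time grid; force is captured by the lambda, not passed to vmap.
    return jax.vmap(lambda y, tt: odeint(force, y, tt, rtol=1e-6, atol=1e-6))(y0, t)


t_f, n_steps = 5.0, 20
t_0 = jnp.array([0.0, 1.0, 2.0, 3.0])  # hypothetical start times
y0 = jax.random.normal(jax.random.PRNGKey(0), (4, 6))
t = t_0[:, None] + (t_f - t_0)[:, None] * jnp.linspace(0.0, 1.0, n_steps)

ys = solve_each(y0, t)  # batch axis first: (N, n_steps, 6)
print(ys.shape)  # (4, 20, 6)
```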
Or even a pmapped version of it:
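```python
def ode_int(y0, t):
    return ode.odeint(force, y0, t)

# pmap already compiles its function, so an outer jax.jit is unnecessary;
# the leading axis of y0 and t must not exceed jax.local_device_count().
inte = jax.pmap(ode_int, in_axes=(0, 0))
```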
However, for various reasons, neither of these is a great solution.
Is there a good way to solve thousands of trajectories with odeint, or perhaps with Diffrax, that I am not aware of?
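One possible workaround, sketched here as an assumption rather than a known fix for this problem: since every trajectory ends at the same t_f, rescale time per trajectory. With t = t_0 + s (t_f − t_0) and s ∈ [0, 1], the ODE becomes dy/ds = (t_f − t_0) f(y, t_0 + s (t_f − t_0)), so all trajectories share the single grid s and can go through one batched `odeint` call, as in the shared-t0 case. The per-trajectory start times travel through `odeint`'s extra `*args`. `force`, `rescaled_force`, `solve_all` and `N_STEPS` are hypothetical names, and the toy `force` ignores t; a real time-dependent force would have to accept an (N, 1) array of times.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

N_STEPS = 50  # number of output points per trajectory (hypothetical choice)


def force(y, t):
    # Hypothetical stand-in right-hand side: 3-D unit harmonic oscillator, y = (x, v).
    # It ignores t; a real time-dependent force must accept t of shape (N, 1).
    x, v = y[..., :3], y[..., 3:]
    return jnp.concatenate([v, -x], axis=-1)


def rescaled_force(y, s, t_start, t_end):
    # Change of variables t = t_start + s * (t_end - t_start), s in [0, 1]:
    # dy/ds = (t_end - t_start) * f(y, t), evaluated per trajectory.
    span = (t_end - t_start)[:, None]   # (N, 1) physical duration of each event
    t = t_start[:, None] + s * span     # (N, 1) physical time of each event
    return span * force(y, t)


@jax.jit
def solve_all(y0, t_start, t_end):
    # One odeint call on the shared grid s; t_start (N,) and t_end () go via *args.
    s = jnp.linspace(0.0, 1.0, N_STEPS)
    ys = odeint(rescaled_force, y0, s, t_start, t_end, rtol=1e-6, atol=1e-6)
    # Map the shared grid back to each trajectory's physical times: (N_STEPS, N).
    ts = t_start[None, :] + s[:, None] * (t_end - t_start)[None, :]
    return ts, ys


# Usage: 200 events with random start times in [0, 4) and a common end time 5.
key_y, key_t = jax.random.split(jax.random.PRNGKey(0))
y0 = jax.random.normal(key_y, (200, 6))
t_start = jax.random.uniform(key_t, (200,), minval=0.0, maxval=4.0)
t_end = jnp.asarray(5.0)
ts, ys = solve_all(y0, t_start, t_end)
print(ts.shape, ys.shape)  # (50, 200) (50, 200, 6)
```

Because the batch is integrated as one ODE system, the adaptive step size and error control are shared across all trajectories, so a single fast-varying event forces small steps for the whole batch. The same change of variables should carry over to Diffrax, where `t0=0` and `t1=1` would then be common to all trajectories.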