How to improve speed and efficiency in ODE workloads with JAX #6464
Hello JAX community! I started using the library a few weeks ago and I'm really enjoying it so far. :-) I want to contribute to the project in the near future, which is why I am practicing and getting some hands-on experience with it now by porting a personal NumPy project of mine about ODEs to JAX. But after porting the first two solvers today (the Euler and Heun methods, so very vanilla ones) and running some benchmarks, I noticed a big drop in performance compared to NumPy. Right now, my port is very unsophisticated: so far, I really only switched …

Some things which I suspect are killing performance right now: …
So obviously, I am still doing a lot of things wrong. I was hoping to have a discussion about some best practices for JAX (and maybe ODE solving with JAX specifically) that I can use to write more efficient JAX code in the future.

All the best,

PS: I am aware of the odeint example in the …
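The PS refers to JAX's own `odeint`, an adaptive-step (Dormand–Prince) solver. It can serve as a reference when benchmarking hand-written fixed-step solvers. Below is a minimal sketch of a call. It assumes `jax.experimental.ode.odeint` is available in your JAX version, since it lives in an experimental module. The decay problem and the relaxed tolerances are illustrative choices, not taken from the discussion.

```python
import jax.numpy as jnp
from jax.experimental.ode import odeint


# Illustrative problem: exponential decay dy/dt = -k * y.
# odeint expects the vector field with signature func(y, t, *args).
def decay(y, t, k):
    return -k * y


y0 = jnp.array([1.0])
ts = jnp.linspace(0.0, 1.0, 11)  # times at which the solution is reported
k = jnp.array(0.5)               # extra argument forwarded through *args

# Tolerances are relaxed because the default float32 precision cannot meet
# the very tight default rtol/atol.
ys = odeint(decay, y0, ts, k, rtol=1e-6, atol=1e-6)

print(ys[-1, 0], jnp.exp(-0.5))  # numerical vs. exact value at t = 1
```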
Awesome to hear you're enjoying exploring JAX!

I should preface this by saying that I'm not an expert in ODEs, but as far as general tips for using JAX go:

JAX is really designed to be used under `jit` as much as possible. Un-jitted `jax.numpy` functions are generally slower than their NumPy counterparts, because each call is dispatched eagerly as a separate operation with its own overhead, while NumPy is optimized for exactly that kind of op-by-op execution. Plus, `jit` is what lets XLA fuse your operations into efficient kernels on accelerators like GPUs and TPUs.

As a rule of thumb, try to `jit` the top-most function of your algorithm. If you `jit` at the top level, the compiler can see everything underneath and make nice optimizations. When I first started using `jit` I had a lot of issues, but using `static_argnums` can help make this …

Secondly, you'll want to avoid Python loops and control flow with JAX. Raw Python loops are slow if they're not under a `jit`, and inside a jitted function they can blow up the compilation time. JAX unrolls Python loops while tracing the function, so the compiler sees one copy of the loop body per iteration. This generally makes compilation time grow superlinearly with the length of the loop. Replacing for loops with … Using …

It could be worth checking out this repository, which has quite a few ODE solvers implemented in JAX: …
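To make the advice about jitting the top-level function concrete, here is a minimal sketch of a fixed-step Euler solver. It combines `static_argnums` with a `jax.lax.scan` loop, and it is not code from the discussion. The names `euler_solve` and `decay` are hypothetical. The solver assumes the vector field has the signature `f(y, t)`. Both the Python callable `f` and the step count `n_steps` are marked static because neither can be traced: a function is not an array, and `scan` needs a concrete length.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Jit the whole solver at the top level. Argument 0 (the callable f) and
# argument 4 (n_steps, used as the scan length) are static.
@partial(jax.jit, static_argnums=(0, 4))
def euler_solve(f, y0, t0, dt, n_steps):
    t0 = jnp.asarray(t0, dtype=y0.dtype)  # keep the time carry's dtype fixed

    def step(carry, _):
        y, t = carry
        y_next = y + dt * f(y, t)         # explicit Euler update
        return (y_next, t + dt), y_next   # (new carry, per-step output)

    # lax.scan traces `step` once instead of unrolling n_steps copies of it.
    _, ys = jax.lax.scan(step, (y0, t0), xs=None, length=n_steps)
    return jnp.concatenate([y0[None], ys])  # prepend the initial state


# Hypothetical vector field: exponential decay dy/dt = -0.5 * y.
def decay(y, t):
    return -0.5 * y


ys = euler_solve(decay, jnp.array([1.0, 2.0]), 0.0, 0.01, 100)
print(ys.shape)  # (101, 2)
```

Because `n_steps` is static, each new value triggers a fresh compilation. Changing `y0` or `dt` (with the same shapes and dtypes) reuses the compiled function.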
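To illustrate the point about Python loops, the sketch below writes Heun's method twice and is not taken from the discussion. The first version uses a plain Python `for` loop, which `jit` unrolls into `n_steps` copies of the step. The second uses `jax.lax.fori_loop`, which traces the loop body once. The helper names (`heun_step`, `rhs`, `heun_python_loop`, `heun_fori_loop`) and the decay vector field are hypothetical. The sketch also assumes only the final state is needed. If you want the whole trajectory, a `scan`-based loop is the natural fit.

```python
import jax
import jax.numpy as jnp


def heun_step(f, y, t, dt):
    # Heun's method: Euler predictor followed by a trapezoidal corrector.
    k1 = f(y, t)
    k2 = f(y + dt * k1, t + dt)
    return y + 0.5 * dt * (k1 + k2)


def rhs(y, t):
    return -0.5 * y  # illustrative vector field (exponential decay)


def heun_python_loop(y0, dt, n_steps):
    # Under jit, this loop is unrolled at trace time: n_steps copies of the
    # step end up in the compiled program, so n_steps must be static.
    y, t = y0, 0.0
    for _ in range(n_steps):
        y = heun_step(rhs, y, t, dt)
        t = t + dt
    return y


def heun_fori_loop(y0, dt, n_steps):
    # fori_loop traces `body` once; n_steps may even be a traced value.
    def body(i, y):
        return heun_step(rhs, y, i * dt, dt)

    return jax.lax.fori_loop(0, n_steps, body, y0)


y0 = jnp.array([1.0, 2.0])
y_unrolled = jax.jit(heun_python_loop, static_argnums=2)(y0, 0.01, 100)
y_looped = jax.jit(heun_fori_loop)(y0, 0.01, 100)
print(y_unrolled, y_looped)  # both versions should agree to float32 precision
```

Both versions compute the same result. The difference shows up in compile time, which for the unrolled version grows with `n_steps` while the `fori_loop` version's stays roughly constant.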