
Awesome to hear you're enjoying exploring jax!
I should preface this by saying that I'm not an expert in ODEs, but as far as general tips for using jax go:

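Jax is really designed to be used under jit as much as possible. Un-jitted jax equivalents of numpy functions are often slower than their numpy counterparts: each jax operation is dispatched on its own with some per-call overhead, while numpy is optimized for exactly that kind of eager, op-by-op execution. Plus, jit is what lets XLA compile and fuse your computation for accelerators like GPUs and TPUs; un-jitted code still runs there, but one op at a time.
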
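
A minimal sketch of the difference, assuming a toy elementwise function (`step` is a made-up example, not anything from your code). The eager call dispatches each op separately, while `jax.jit` traces the function once per input shape/dtype and compiles it with XLA. The results should agree up to float rounding:

```python
import jax
import jax.numpy as jnp

def step(x):
    # Hypothetical toy function: a few elementwise ops.
    # Without jit, each op below is dispatched to the device separately.
    return jnp.sin(x) * 2.0 + jnp.cos(x) ** 2

# jax.jit traces `step` on the first call for a given shape/dtype and
# compiles it with XLA, which can fuse the elementwise ops together.
step_jit = jax.jit(step)

x = jnp.linspace(0.0, 1.0, 1000)
eager = step(x)          # op-by-op execution
compiled = step_jit(x)   # first call compiles; later calls reuse the compiled program

# JAX dispatches asynchronously, so wait for the result before timing anything.
compiled.block_until_ready()
```

If you benchmark this, time the second and later calls of `step_jit`; the first one includes compilation.
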
As a rule of thumb, try to jit the top-most function of your algorithm. If you jit at the top level, the compiler can see everything underneath it and optimize across the whole computation, e.g. by fusing operations. When I first started using jit I had a lot of issues, but using static_argnums can help make this …
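
To make the top-level jit and `static_argnums` advice concrete, here's a rough sketch, assuming a simple explicit Euler solver (`euler_integrate` and the right-hand side `decay` are hypothetical names for illustration, not a recommended ODE method). The whole solver is jitted, and `num_steps` is marked static because it drives a Python `for` loop, which needs a concrete int rather than a traced value:

```python
from functools import partial

import jax
import jax.numpy as jnp

def decay(y, t):
    # Hypothetical right-hand side dy/dt = -y, for illustration only.
    return -y

# jit the top-level solver so XLA sees the whole loop, including every call to `decay`.
# static_argnums=(2,) marks `num_steps` as static: it is baked into the compiled
# program, and calling with a different value triggers a recompile.
@partial(jax.jit, static_argnums=(2,))
def euler_integrate(y0, dt, num_steps):
    y = y0
    t = 0.0
    # A Python loop over a traced value would fail; a static int is fine here.
    for _ in range(num_steps):
        y = y + dt * decay(y, t)
        t = t + dt
    return y

y_final = euler_integrate(jnp.array([1.0, 2.0]), 0.01, 100)
```

One design caveat: the Python loop gets fully unrolled during tracing, so compile time grows with `num_steps`. For long integrations, `jax.lax.fori_loop` or `jax.lax.scan` keeps the loop inside the compiled program instead.
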

Answer selected by nicholasjng