-
Hello, I use the `odeint` function from `jax.experimental.ode` in my work. You implemented the custom_vjp functionality for this function (which differentiates the ODE system and solves it).
Replies: 2 comments
-
Thanks for the question! I wrote up my own derivation for differentiation of ODEs here: https://implicit-layers-tutorial.org/implicit_functions/, specifically this subsection. I like that derivation because it's simple and mechanical: no steps require creative leaps, and it doesn't involve importing extraneous machinery (like Lagrange multipliers). Also, unlike any other derivation I've seen, it allows reverse-mode differentiation to be decomposed into linearization, partial evaluation, and transposition; that makes it a natural fit for JAX's autodiff system, though our implementation hasn't yet caught up. Some other resources:
There are probably lots of other resources with which I'm not familiar. WDYT?
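For concreteness, here is a minimal sketch of what differentiating through `odeint` looks like from the user's side. It is only an illustration, not the derivation itself: the `decay` dynamics, the `final_state` helper, and the tolerance values are assumptions picked for this example so that the gradient can be checked against a closed form.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint


def decay(y, t, theta):
    # Hypothetical toy dynamics: dy/dt = -theta * y, so y(t) = y0 * exp(-theta * t).
    return -theta * y


def final_state(theta):
    # Solve from t=0 to t=1 starting at y0=1 and return y(1).
    ts = jnp.array([0.0, 1.0])
    ys = odeint(decay, jnp.array(1.0), ts, theta, rtol=1e-6, atol=1e-6)
    return ys[-1]


theta = jnp.array(0.5)

# Reverse-mode gradient of y(1) with respect to theta, computed through odeint.
grad_numeric = jax.grad(final_state)(theta)

# Closed form for comparison: d/dtheta exp(-theta * 1) = -exp(-theta).
grad_exact = -jnp.exp(-theta)
```

The two values should agree up to the solver tolerances, which is a handy sanity check when reading any of the derivations.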
-
Shameless self-promotion: by far the simplest derivation I know of is that of Section 5.1.2.1 of On Neural Differential Equations, which gives a derivation in just 5 lines of elementary calculus. If you're curious to know more: Theorem 5.2 on the previous page gives the statement being proved; Appendix C.3.1 gives an ever-so-slightly-longer derivation with full mathematical rigor; finally, Diffrax is a JAX-based faster and more featureful alternative to