PyBaMM JAX-based BDF solver: completely differentiable? #3453
mayor-slash asked this question in Q&A
-
Hi @mayor-slash, the method you are looking for is |
-
This question relates only to the JAX-based BDF solver in PyBaMM:
I have an ODE system $\frac{dY}{dt} = f(Y, P)$, where $P$ is a set of parameters. I compare the resulting solution $Y(t)$ against reference data to compute a loss value $L$. My goal is to calculate the gradient $\frac{dL}{dP}$ so that I can optimize the parameters $P$ to minimize $L$.
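For reference, the gradient in question can be written in terms of the forward sensitivities $S(t) = \partial Y(t) / \partial P$. This assumes, for illustration, that $L$ depends on $Y$ only at a finite set of output times $t_1, \dots, t_K$. Differentiating through a solver computes a discretized version of this quantity, either forward (sensitivities) or in reverse (adjoint).

$$
\frac{dL}{dP} = \sum_{k=1}^{K} \frac{\partial L}{\partial Y(t_k)}\, S(t_k),
\qquad
\frac{dS}{dt} = \frac{\partial f}{\partial Y}\, S + \frac{\partial f}{\partial P},
\qquad
S(0) = \frac{\partial Y(0)}{\partial P}
$$

If the initial condition does not depend on $P$, then $S(0) = 0$. The sensitivity equations share the Jacobian $\partial f / \partial Y$ with the original system, so they are just as stiff as $f$ itself.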
The JAX-based library Diffrax is perfect for this job, as it lets me differentiate through the ODE solve. However, my function $f(Y, P)$ is too complex, so the implicit solvers it provides can't solve it properly. SciPy's `solve_ivp` with the BDF method solves the system properly, but it doesn't allow backward differentiation through the solver, so I can't compute my gradients with it.

Since the solver in PyBaMM is also JAX-based, I tried it on my system. As expected, the forward solve works properly and produces the same result as the SciPy function (both are based on the same paper). However, I can't compute the gradients with this solver either.
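As an illustration of what "differentiating through the solve" looks like in JAX, here is a minimal sketch that uses `jax.experimental.ode.odeint` instead of Diffrax or PyBaMM. The right-hand side `f` (linear decay) and the synthetic reference data are hypothetical. `odeint` is an explicit adaptive Runge-Kutta (Dormand-Prince) solver that provides gradients through an adjoint, so it shows the pattern but is not a substitute for a stiff BDF solver.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

# Double precision is usually advisable for ODE work in JAX.
jax.config.update("jax_enable_x64", True)


def f(y, t, p):
    """Hypothetical right-hand side dY/dt = f(Y, P): linear decay with rate p[0]."""
    return -p[0] * y


t = jnp.linspace(0.0, 1.0, 11)
y0 = jnp.array([1.0])
# Synthetic reference data, generated from the exact solution with p = 2.
y_ref = jnp.exp(-2.0 * t)[:, None]


def loss(p):
    # odeint(func, y0, t, *args) returns Y at every time in t, shape (len(t), 1).
    y = odeint(f, y0, t, p)
    return jnp.mean((y - y_ref) ** 2)


# Reverse-mode gradient dL/dP taken straight through the solver.
grad_loss = jax.grad(loss)

p = jnp.array([1.5])
print("L =", loss(p), "dL/dP =", grad_loss(p))
```

Starting from `p = 1.5`, the gradient is negative, i.e. increasing the decay rate toward the value used to generate the data lowers the loss.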
My guess is that somewhere in the implementation there is impure JAX code that prevents automatic differentiation. My question is: was differentiating through the solver ever an intended feature of PyBaMM? Or is it already used, and I made a mistake in my implementation?
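To make the "impure code" hypothesis concrete, here is a toy sketch (not taken from PyBaMM's source) of two patterns that run fine in a forward solve but make `jax.grad` fail: routing traced values through NumPy, which raises `TracerArrayConversionError`, and a `jax.lax.while_loop` time-stepping loop, which JAX supports in forward mode (`jax.jvp`) but not in reverse mode. The functions `euler_while` and `euler_numpy` and the step settings are hypothetical; whether either pattern applies to PyBaMM's BDF solver is not established here.

```python
import jax
import jax.numpy as jnp
import numpy as np

H, N_STEPS = 0.1, 10  # hypothetical step size and number of steps


def euler_while(p):
    """Explicit Euler for dy/dt = -p*y, y(0) = 1, written with lax.while_loop."""
    def cond(state):
        i, _ = state
        return i < N_STEPS

    def body(state):
        i, y = state
        return i + 1, y + H * (-p * y)

    _, y_end = jax.lax.while_loop(cond, body, (jnp.int32(0), jnp.float32(1.0)))
    return y_end


def euler_numpy(p):
    """Same update, but the state is routed through NumPy (impure for JAX)."""
    y = jnp.float32(1.0)
    for _ in range(N_STEPS):
        y = jnp.asarray(np.asarray(y) + H * (-np.asarray(p) * np.asarray(y)))
    return y


p = 2.0

# Forward mode works through while_loop: returns y(T) and dy(T)/dp.
y_end, dy_dp = jax.jvp(euler_while, (p,), (1.0,))
print("forward-mode dy/dp:", dy_dp)

# Reverse mode (jax.grad) fails for both versions, for different reasons.
failures = {}
for name, fn in [("while_loop", euler_while), ("NumPy", euler_numpy)]:
    try:
        jax.grad(fn)(p)
        failures[name] = None
    except (ValueError, jax.errors.TracerArrayConversionError) as err:
        failures[name] = type(err).__name__
print(failures)
```

Reverse-mode support for a loop like this typically comes from a `jax.custom_vjp` rule (for example, one that solves an adjoint ODE), which is how `jax.experimental.ode.odeint` supports `jax.grad` despite looping internally.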