grad + vmap + odeint AssertionError #8782
-
Hi, I've seen a bunch of discussions and issues surrounding this, so I apologize if I'm re-raising something that has already been addressed elsewhere. I don't understand the internals of JAX well enough to tell whether this is a version of an issue that has already been raised. Much of what I've found is in issues that are already closed, so I assume those are solved and this is something different. When I run the following code (a self-contained version of my actual code) on the latest release versions of jax/jaxlib:
I get the error:
I've definitely reverse-mode differentiated this code successfully in the past, though that was some time ago, so it would have been on a much older version of jax/jaxlib. Based on other similar issues I've seen, this seems to have something to do with the interaction of reverse-mode autodiff, vmap, and the control flow used in odeint.
Thanks!
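The reporter's original script was not preserved in this copy of the thread. For readers who want to see the shape of the composition the title refers to, here is a minimal sketch of reverse-mode `jax.grad` taken through a `jax.vmap` over `jax.experimental.ode.odeint`. The toy decay problem `dy/dt = -k * y` and the names `dynamics`, `final_state`, and `loss` are hypothetical illustrations, not the reporter's code, and this sketch is not claimed to reproduce the AssertionError. On a JAX version without the bug it should simply run, and because `y(1) = exp(-k)`, the gradient should be close to `-exp(-k)`.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

# Hypothetical toy dynamics: exponential decay with rate k.
# odeint calls func(y, t, *args), so k arrives as an extra argument.
def dynamics(y, t, k):
    return -k * y

# Integrate from t=0 to t=1 starting at y=1 and return y(1) = exp(-k).
def final_state(k):
    ts = jnp.array([0.0, 1.0])
    ys = odeint(dynamics, jnp.array(1.0), ts, k)
    return ys[-1]

# vmap the solve over a batch of rates, then reduce to a scalar for grad.
def loss(ks):
    return jnp.sum(jax.vmap(final_state)(ks))

ks = jnp.array([0.5, 1.0, 2.0])
# Reverse-mode gradient through vmap + odeint; analytically d/dk exp(-k) = -exp(-k).
grads = jax.grad(loss)(ks)
print(grads)
```

Swapping `jax.grad` for a forward-mode transform such as `jax.jacfwd`, or removing the `vmap`, is a quick way to check which part of the composition a given failure depends on.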
-
Thanks for raising this. It certainly looks like a bug! For that reason I opened a corresponding bug issue, #8783. (Also I wanted to try out a button I just noticed on the GitHub UI...) Let's follow up there!