Hi, I'd also recommend taking a look at torch2jax. I'm not 100% sure, but it may solve the problem in a convenient way.
I'm writing a package using JAX, which calls another one based on Autograd. How can I propagate the gradient?
I can think of two options:
The second option seems better, since I won't have to write the chain rule myself, but it requires knowing how dual numbers are handled by JAX and Autograd. Can you please help me?
Here is a simple code example to help develop this feature.