Replies: 1 comment 6 replies
I am aware of #1937 but still have no idea how to proceed in my case. There are actually a bunch of variables like
`jax.grad(f1, has_aux=True)` takes 5 seconds compared to 10 ms with `jax.grad(f2, has_aux=True)`. In this case I just want to get the gradient of `J`. `J_no_grad` is for recording some intermediate values. How could I reduce the computation time? Can we do something like:

Thanks!
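For readers with the same problem, here is a minimal sketch of one approach that may help. It is not the snippet from the original post: `loss_with_aux` and the toy formulas for `J` and `J_no_grad` are hypothetical stand-ins for the real `f1`. The idea is to compute the recorded values from inputs wrapped in `jax.lax.stop_gradient`, so that autodiff treats that branch as a constant, and to wrap the gradient function in `jax.jit` so tracing and compilation happen once rather than on every call.

```python
import jax
import jax.numpy as jnp


def loss_with_aux(x):
    # Differentiable objective (hypothetical stand-in for J).
    J = jnp.sum(x ** 2)

    # Hypothetical stand-in for J_no_grad: values recorded for inspection only.
    # stop_gradient makes autodiff treat x_const as a constant, so no
    # derivative computation is traced through this branch.
    x_const = jax.lax.stop_gradient(x)
    J_no_grad = jnp.sum(jnp.sin(x_const))

    # With has_aux=True, only the first output is differentiated;
    # the second is returned unchanged alongside the gradient.
    return J, J_no_grad


# jit compiles the gradient function once; later calls reuse the compiled code.
grad_fn = jax.jit(jax.grad(loss_with_aux, has_aux=True))

x = jnp.arange(3.0)
g, aux = grad_fn(x)
```

Whether this closes the gap between the two timings depends on where the 5 seconds actually go. If most of it is tracing or compilation of the auxiliary computation, the first call will still be slow and only repeated calls benefit from `jax.jit`.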