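Is it possible to get the gradients of the output with respect to intermediate variables?

```python
import jax.numpy as jnp
from jax import grad

def function(w, x):
    ...  # function body not preserved in this copy of the thread

# Gradient of the output with respect to the input w
dL_dw = grad(function)(w, x)
```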
Answered by xmax1 on Mar 28, 2021 (1 reply)
Never mind, this is answered in #5336.
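One common way to get gradients with respect to an intermediate in JAX is sketched below; it is not necessarily the approach described in #5336. The idea is to add a zero-valued perturbation `delta` to the intermediate `h` and differentiate with respect to `delta`: by the chain rule, the gradient with respect to `delta` at `delta = 0` equals the gradient with respect to `h`. The function `loss_with_perturbation`, its `tanh`/sum-of-squares body, and the input values are hypothetical stand-ins, since the original function body is not preserved.

```python
import jax
import jax.numpy as jnp

# Hypothetical model: intermediate h = tanh(w * x) feeds a scalar loss L.
# `delta` is an extra input added to h; evaluated at zero it leaves the
# forward value unchanged, and dL/d(delta) equals dL/dh.
def loss_with_perturbation(w, x, delta):
    h = jnp.tanh(w * x) + delta  # intermediate variable plus perturbation
    return jnp.sum(h ** 2)       # scalar output L

w = jnp.array([0.5, -1.0, 2.0])
x = jnp.array([1.0, 2.0, 3.0])
delta0 = jnp.zeros_like(w)  # must have the same shape as the intermediate h

# argnums=(0, 2): gradient w.r.t. w, and w.r.t. h (through delta)
dL_dw, dL_dh = jax.grad(loss_with_perturbation, argnums=(0, 2))(w, x, delta0)
```

Because `jax.grad` is given a tuple for `argnums`, it returns a tuple of gradients, so one backward pass yields both the parameter gradient and the intermediate gradient.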
Answer selected by xmax1