I'm sorry to ask a question that probably has an easy answer, but I can't figure it out. My function looks like this:

```python
def forward(params, x, data):
    ...
    return loss, x
```

But if I return `x` alongside `loss` like this, I can't calculate the gradient of the function.
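A minimal sketch of what is presumably the failure described above: `jax.grad` only differentiates scalar-valued functions, so a function that returns a `(loss, extra)` tuple is rejected with a `TypeError`. The toy function `f` is hypothetical and only stands in for `forward`.

```python
import jax

# Hypothetical toy stand-in for `forward`: returns a (loss, extra) tuple
# instead of a single scalar.
def f(x):
    return x ** 2, x + 1

try:
    jax.grad(f)(3.0)
    failed = False
except TypeError as err:
    # jax.grad requires a scalar output, so the tuple is rejected here.
    failed = True
    print("jax.grad raised:", err)
```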
You can do this using the `has_aux` argument to `jax.grad`, which lets you compute the gradient of a function that returns its value along with auxiliary data. For example:

```python
from jax import grad

def f(x, y):
    # First element is differentiated; second is auxiliary data.
    return y * x, y + 1

grad_f = grad(f, has_aux=True)
print(*grad_f(2.0, 3))
# prints: 3.0 4
```