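You can do this with the has_aux argument to jax.grad. With has_aux=True, the function must return a pair (value, aux): the first element is the scalar output to differentiate, and the second is auxiliary data that is passed through without being differentiated. The transformed function then returns (gradient, aux). For example: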

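from jax import grad

def f(x, y):
  # First output is differentiated; second is returned unchanged as auxiliary data.
  return y * x, y + 1

# By default, grad differentiates with respect to the first argument (x).
grad_f = grad(f, has_aux=True)

# grad_f returns (df/dx, aux) = (y, y + 1).
print(*grad_f(2.0, 3))
# 3.0 4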
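
If you also need the differentiated value itself, jax.value_and_grad accepts the same has_aux argument and returns ((value, aux), gradient). Here is a minimal sketch; loss_fn, w, and x are hypothetical names chosen for illustration and are not from the question:

```python
import jax
import jax.numpy as jnp

def loss_fn(w, x):
    # Hypothetical loss: sum of squared predictions.
    pred = w * x
    loss = jnp.sum(pred ** 2)
    # Return (value to differentiate, auxiliary data).
    return loss, pred

w = jnp.array(3.0)
x = jnp.array([1.0, 2.0, 3.0])

# grad with has_aux=True returns (gradient w.r.t. w, aux).
grad_w, pred = jax.grad(loss_fn, has_aux=True)(w, x)

# value_and_grad with has_aux=True returns ((value, aux), gradient).
(loss, pred_again), grad_again = jax.value_and_grad(loss_fn, has_aux=True)(w, x)

print(loss, grad_w, pred)
```

The auxiliary output can be any pytree (for example a dict of metrics), which is handy for returning intermediate results from a loss function without computing them twice.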

Answer selected by sihengz02