Is something like this what you have in mind?

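import jax

def new_grad(f):
  # Returns a function computing d2f/dx2 + d2f/dy2 for scalar x and y
  def f_grad(x, y):
    return jax.hessian(f, argnums=0)(x, y) + jax.hessian(f, argnums=1)(x, y)
  return f_grad

f = lambda x, y: x * y
result = new_grad(f)(1.0, 2.0)  # 0.0: both second derivatives of x * y vanish

new_grad(lambda x, y: x ** y)(1.0, 2.0)  # 2.0: y*(y-1)*x**(y-2) + x**y*log(x)**2 at (1, 2)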
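For scalar arguments, `jax.hessian(f, argnums=0)` reduces to the plain second derivative d2f/dx2, so the same quantity can also be written with two nested `jax.grad` calls. A minimal sketch, assuming scalar float inputs; `new_grad_nested` is a hypothetical name used only for illustration, not part of JAX:

```python
import jax

def new_grad_nested(f):
    # Hypothetical variant: for scalar x and y, differentiating twice with
    # respect to the same argument gives that argument's second derivative.
    d2f_dx2 = jax.grad(jax.grad(f, argnums=0), argnums=0)
    d2f_dy2 = jax.grad(jax.grad(f, argnums=1), argnums=1)

    def f_grad(x, y):
        return d2f_dx2(x, y) + d2f_dy2(x, y)

    return f_grad

def new_grad(f):
    # Hessian-based version, for comparison
    def f_grad(x, y):
        return jax.hessian(f, argnums=0)(x, y) + jax.hessian(f, argnums=1)(x, y)
    return f_grad

# Both versions should agree on scalar inputs
for g in (lambda x, y: x * y, lambda x, y: x ** y, lambda x, y: x ** 3 * y ** 2):
    print(new_grad(g)(1.5, 2.0), new_grad_nested(g)(1.5, 2.0))
```

`jax.hessian` is the more general tool: for an array-valued `x` or `y` it returns the full matrix of second derivatives for that argument, whereas the nested-`grad` form only works for scalar arguments, because `jax.grad` requires a scalar output.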

Replies: 3 comments 3 replies

@exenGT
@exenGT
@jakevdp
Answer selected by exenGT
Category
Q&A
2 participants