
Great questions!

For your first question, the answer is yes: use autodiff to get the function (u, v) -> u' H(x) v, where H(x) is the Hessian of f at x. With that function you can extract any element of the Hessian by feeding in standard basis vectors e_i and e_j for u and v.
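The identity behind this, assuming f is twice continuously differentiable (so the Hessian is symmetric), is that two nested directional derivatives give the bilinear form, and basis vectors pick out single entries:

```latex
u^\top H(x)\, v
  = \frac{\partial}{\partial s}\Big|_{s=0}\,
    \frac{\partial}{\partial t}\Big|_{t=0}\,
    f(x + t\,u + s\,v),
\qquad
H_{ij}(x) = e_i^\top H(x)\, e_j .
```

The inner derivative is the directional derivative of f along u, which is what one `jax.jvp` computes; differentiating that result along v is the second `jax.jvp`.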

More concretely:

import jax
import jax.numpy as jnp

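def hessian_bilinear_form(f, x, u, v):
  # forward-over-forward autodiff (a JVP of a JVP):
  # the inner JVP gives the directional derivative grad f(x) . u,
  # the outer JVP differentiates that along v, giving u' H(x) v.
  f_jvp = lambda x: jax.jvp(f, (x,), (u,))[1]
  return jax.jvp(f_jvp, (x,), (v,))[1]

def hessian_entry(f, x, i, j):
  # H[i, j] = e_i' H e_j, with e_i and e_j standard basis vectors
  u = jnp.zeros_like(x).at[i].set(1.)
  v = jnp.zeros_like(x).at[j].set(1.)
  return hessian_bilinear_form(f, x, u, v)

A = jnp.array([[1., 2.],
               [3., 4.]])
f = lambda x: jnp.dot(x, jnp.dot(A, x))

# The Hessian of x' A x is A + A', so entry (0, 1) is 2 + 3 = 5.
x = jnp.array([1., 1.])
print(hessian_entry(f, x, 0, 1))  # 5.0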
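As an alternative sketch (not part of the original answer), the same quantities can be obtained forward-over-reverse: apply `jax.jvp` to `jax.grad(f)` to get a Hessian-vector product H(x) v, then dot it with u. Mapping that product over the identity matrix with `jax.vmap` recovers the full Hessian, which can be checked against `jax.hessian`. The helper names `hvp`, `hessian_bilinear_form_fwd_rev` and `hessian_via_hvp`, and the evaluation point `x = [1, 1]`, are hypothetical choices for illustration.

```python
import jax
import jax.numpy as jnp

# Forward-over-reverse Hessian-vector product: differentiate the gradient
# (reverse mode) along v with a JVP (forward mode). Returns H(x) @ v.
def hvp(f, x, v):
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

# u' H(x) v via one Hessian-vector product followed by a dot product.
def hessian_bilinear_form_fwd_rev(f, x, u, v):
    return jnp.dot(u, hvp(f, x, v))

# Full Hessian by pushing every standard basis vector through hvp.
# Row j of the result is H(x) @ e_j (column j of H); for a smooth f,
# H is symmetric, so this equals H itself.
def hessian_via_hvp(f, x):
    basis = jnp.eye(x.shape[0], dtype=x.dtype)
    return jax.vmap(lambda e: hvp(f, x, e))(basis)

A = jnp.array([[1., 2.],
               [3., 4.]])
f = lambda x: jnp.dot(x, jnp.dot(A, x))
x = jnp.array([1., 1.])

H_ref = jax.hessian(f)(x)   # JAX's built-in Hessian, for comparison
H = hessian_via_hvp(f, x)   # should match A + A' = [[2, 5], [5, 8]]
entry_01 = hessian_bilinear_form_fwd_rev(
    f, x, jnp.array([1., 0.]), jnp.array([0., 1.]))

print(H)
print(entry_01)
```

`jax.hessian` is itself documented as `jacfwd(jacrev(f))`, i.e. forward-over-reverse. For a single entry the forward-over-forward version in the answer needs no reverse pass, while the HVP form is convenient when you want whole rows or columns of H.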

Answer selected by stefano-mossa

This discussion was converted from issue #8441 on November 04, 2021 02:34.