Hessian calculation #8456
I have two (probably) trivial questions:
Great questions! For your first question, the answer is yes, basically by using autodiff to get access to the function `(u, v) -> u'Hv`. With that function you can extract any element of the Hessian by feeding in standard basis vectors for `u` and `v`. More concretely:

```python
import jax
import jax.numpy as jnp

def hessian_bilinear_form(f, x, u, v):
    # forward-over-forward autodiff: the inner JVP computes the directional
    # derivative grad f(x) . u, and the outer JVP differentiates that along v,
    # giving v' H(x) u (equal to u' H(x) v, since the Hessian is symmetric)
    f_jvp = lambda x: jax.jvp(f, (x,), (u,))[1]
    return jax.jvp(f_jvp, (x,), (v,))[1]

def hessian_entry(f, x, i, j):
    # standard basis vectors e_i and e_j pick out the entry H[i, j]
    u = jnp.zeros_like(x).at[i].set(1)
    v = jnp.zeros_like(x).at[j].set(1)
    return hessian_bilinear_form(f, x, u, v)

A = jnp.array([[1., 2.],
               [3., 4.]])
f = lambda x: jnp.dot(x, jnp.dot(A, x))  # f(x) = x' A x

H_01 = hessian_entry(f, jnp.array([1., 2.]), 0, 1)
print(H_01)

# the Hessian of x' A x is A + A', so this should print the same value
H = A + A.T
print(H[0, 1])
```

For your second question, that's correct: it's not exploiting any symmetry. The …

Let me know if that covers what you had in mind, or if you spot any mistakes!