For a function $h(x;\theta): \mathbb{R}^n \to \mathbb{R}$, I would like to get all second-order derivatives $\partial^2 h(x;\theta)/\partial x_i^2$, i.e., the diagonal elements of the Hessian matrix. Once I have that function, I want to differentiate it with respect to the function parameters: $\partial^3 h(x;\theta)/\partial x_i^2\,\partial\theta$.
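The same two quantities can be written using the $x$-Hessian $\nabla_x^2 h$ and the standard basis vector $e_i$. This notation is introduced here for convenience and is not from the original post:

```math
\frac{\partial^2 h(x;\theta)}{\partial x_i^2} = e_i^\top \, \nabla_x^2 h(x;\theta) \, e_i,
\qquad
\frac{\partial^3 h(x;\theta)}{\partial x_i^2 \, \partial\theta} = \frac{\partial}{\partial\theta} \left( e_i^\top \, \nabla_x^2 h(x;\theta) \, e_i \right),
\qquad i = 1, \dots, n.
```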
I came up with the following implementation; it seems to work on toy examples.
```python
import jax
import jax.numpy as jnp

# an example of a nonlinear hypothesis
def hypothesis(params, x):
    w0, b0 = params
    y = jnp.sum(w0 * x) + b0
    return y * y

# first-order gradient with respect to x
def dhdx(params, x):
    return jax.grad(hypothesis, argnums=1)(params, x)

# directional first derivative of h along `mask`
def helper(params, x, mask):
    return jnp.dot(dhdx(params, x), mask)

# second-order derivative: mask^T H mask
def h2d(params, x, mask):
    return jnp.dot(jax.grad(helper, argnums=1)(params, x, mask), mask)
```
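This hypothesis has a closed-form diagonal, which gives a quick sanity check of the mask approach. With $y = w_0 \cdot x + b_0$ and $h = y^2$, the gradient is $2y\,w_0$ and the Hessian is $2\,w_0 w_0^\top$, so $\partial^2 h/\partial x_i^2 = 2 w_{0,i}^2$. Below is a minimal, self-contained sketch; the numeric values of `w0`, `b0` and `x` are arbitrary assumptions.

```python
import jax
import jax.numpy as jnp

def hypothesis(params, x):
    w0, b0 = params
    y = jnp.sum(w0 * x) + b0
    return y * y

def helper(params, x, mask):
    # directional first derivative along `mask`
    return jnp.dot(jax.grad(hypothesis, argnums=1)(params, x), mask)

def h2d(params, x, mask):
    # mask^T H mask
    return jnp.dot(jax.grad(helper, argnums=1)(params, x, mask), mask)

# arbitrary test values (assumptions, not from the original post)
w0 = jnp.array([0.5, -1.0, 2.0])
b0 = 0.3
x = jnp.array([1.0, 2.0, -0.5])
params = (w0, b0)

# one masked call per basis vector e_1, e_2, e_3
diag_masked = jnp.array([h2d(params, x, e) for e in jnp.eye(3)])
# closed form for h = (w0 . x + b0)^2: d^2h/dx_i^2 = 2 * w0_i^2 (independent of x)
diag_closed = 2.0 * w0 ** 2
print(diag_masked)  # should match diag_closed
```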
Using different masks, I can get $\partial^2 h/\partial x_i^2$ for different $x_i$. For example, `mask = jnp.array([1., 0., ..., 0.])` gives the second partial derivative with respect to $x_1$, and `mask = jnp.array([0., 1., 0., ..., 0.])` gives it with respect to $x_2$.
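One way to avoid writing each mask by hand is to map the masked computation over the rows of an identity matrix with `jax.vmap`, so that `mask` runs through $e_1, \dots, e_n$ in a single call. The sketch below assumes the same hypothesis as in the post. The name `hessian_diag` and the test values are hypothetical. The result is compared against the diagonal of the full `jax.hessian`.

```python
import jax
import jax.numpy as jnp

def hypothesis(params, x):
    w0, b0 = params
    y = jnp.sum(w0 * x) + b0
    return y * y

def h2d(params, x, mask):
    # mask^T H mask via two nested reverse-mode gradients (same construction as h2d in the post)
    helper = lambda x_: jnp.dot(jax.grad(hypothesis, argnums=1)(params, x_), mask)
    return jnp.dot(jax.grad(helper)(x), mask)

def hessian_diag(params, x):
    # hypothetical helper: map h2d over mask = e_1, ..., e_n (rows of the identity)
    masks = jnp.eye(x.shape[0], dtype=x.dtype)
    return jax.vmap(h2d, in_axes=(None, None, 0))(params, x, masks)

# arbitrary test values (assumptions)
w0 = jnp.array([0.5, -1.0, 2.0, 0.1])
b0 = -0.2
x = jnp.array([0.3, 1.5, -0.7, 2.0])
params = (w0, b0)

diag = hessian_diag(params, x)
full = jax.hessian(hypothesis, argnums=1)(params, x)  # full n x n Hessian, for comparison only
```

The `in_axes=(None, None, 0)` argument keeps `params` and `x` fixed and batches only over the masks.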
By doing this, do I actually avoid computing the entire Hessian?
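As far as I can tell, the inner `jax.grad(helper, ...)` returns the whole vector $H e_i$, which is one row of the symmetric Hessian. Only the final `jnp.dot` discards all but one of its entries. So each masked call costs roughly one Hessian-vector product. Looping over all $n$ masks therefore does about as much work as building the full Hessian, although the $n \times n$ matrix is never stored.

The sketch below checks this numerically:
- `h` is a hypothetical test function with a non-diagonal Hessian.
- `hvp` is the forward-over-reverse product that also appears in the #3801 code below.

```python
import jax
import jax.numpy as jnp

def h(x):
    # hypothetical test function with a non-diagonal Hessian
    return jnp.sum(jnp.sin(x) * jnp.roll(x, 1)) + jnp.sum(x ** 3)

def hvp(x, v):
    # forward-over-reverse Hessian-vector product: H(x) @ v
    return jax.jvp(jax.grad(h), (x,), (v,))[1]

x = jnp.array([0.1, -0.4, 0.8, 1.2, -1.0])
H = jax.hessian(h)(x)              # full Hessian, for comparison only
e2 = jnp.zeros(5).at[2].set(1.0)   # mask selecting x_3 (index 2)

# the inner gradient of the mask approach is a whole Hessian row ...
row_via_mask = jax.grad(lambda x_: jnp.dot(jax.grad(h)(x_), e2))(x)
# ... i.e. the same vector as a Hessian-vector product
row_via_hvp = hvp(x, e2)
# only this final dot product reduces it to the single diagonal entry
diag_entry = jnp.dot(row_via_hvp, e2)
```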
I am very new to JAX and any suggestions are welcome. Thank you!
Update:
The solution in #3801 (comment) does calculate all the diagonal elements of the Hessian; see the code below. But I don't know how to get the gradient of that Hessian diagonal with respect to a parameter (e.g., the parameter `a` in the code below): $\partial^3 f(x)/\partial x_i^2\,\partial a$. Is there a way to "parameterize" the Hessian function and take partial derivatives with respect to its parameters?
```python
from jax import jvp, grad, hessian
import jax.numpy as jnp
import numpy.random as npr

rng = npr.RandomState(0)
a = rng.randn(4)
x = rng.randn(4)

# function with diagonal Hessian that isn't rank-polymorphic
def f(x):
    assert x.ndim == 1
    return jnp.sum(jnp.tanh(a * x))

# Hessian-vector product: forward-mode JVP of the reverse-mode gradient
def hvp(f, x, v):
    return jvp(grad(f), (x,), (v,))[1]
```
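Here is a minimal sketch of one way to do this. It assumes `a` can be passed to `f` as an explicit argument instead of being read from a global. Once the diagonal is an ordinary function of `(a, x)`, JAX transformations such as `jax.jacobian` or `jax.grad` can differentiate it with respect to `a`.

The helper names `hessian_diag` and `hessian_diag_ones` are hypothetical:
- **General version:** uses one Hessian-vector product per basis vector, like the masks in the first approach.
- **`ones` variant:** valid only because this particular `f` has a diagonal Hessian. In that case $H \mathbf{1}$ equals the diagonal.

```python
import jax
import jax.numpy as jnp
import numpy.random as npr

rng = npr.RandomState(0)
a = jnp.asarray(rng.randn(4))
x = jnp.asarray(rng.randn(4))

# f now takes the parameter `a` explicitly instead of closing over a global
def f(a, x):
    assert x.ndim == 1
    return jnp.sum(jnp.tanh(a * x))

def hessian_diag(a, x):
    # diagonal of the x-Hessian: e_i^T H e_i for every basis vector e_i,
    # with each H e_i computed as a forward-over-reverse Hessian-vector product
    grad_x = jax.grad(f, argnums=1)
    def hvp(v):
        return jax.jvp(lambda x_: grad_x(a, x_), (x,), (v,))[1]
    basis = jnp.eye(x.shape[0], dtype=x.dtype)
    return jax.vmap(lambda e: jnp.dot(e, hvp(e)))(basis)

def hessian_diag_ones(a, x):
    # only valid when the x-Hessian is diagonal (true for this f): H @ ones == diag(H)
    grad_x = jax.grad(f, argnums=1)
    return jax.jvp(lambda x_: grad_x(a, x_), (x,), (jnp.ones_like(x),))[1]

# derivative of every diagonal entry w.r.t. a: entry [i, j] = d^3 f / (dx_i^2 da_j)
d_diag_da = jax.jacobian(hessian_diag, argnums=0)(a, x)
```

If only a scalar function of the diagonal is needed (for example, its sum inside a loss), `jax.grad` of that scalar with respect to `a` works the same way.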