Cholesky of inverse Hessian in HVP #12867
Unanswered · silasbrack asked this question in Q&A
Hi,

For context, I'm trying to run the Laplace approximation for a neural network in JAX and need the Hessian matrix of my loss function w.r.t. my NN's parameters. Computing it is easy enough; however, I don't want to store the entire Hessian because my parameter space can be too large to hold the full square matrix.

All I need the Hessian for, though, is to estimate the covariance matrix of my weight posterior, so what I really need is to be able to sample from a normal distribution whose covariance equals the inverse of this Hessian ($H^{-1} = \Sigma$). Thanks to the reparameterization trick, I can sample by multiplying a vector $\epsilon$ (drawn from a standard normal) by the Cholesky decomposition of my covariance matrix: $\mathrm{Chol}(H^{-1}) \cdot \epsilon$ [1].

I managed to multiply the inverse of my Hessian by $\epsilon$ (i.e., $v = H^{-1} \cdot \epsilon$) by using the conjugate gradient method (leveraging $H \cdot v = \epsilon$) to find $v$.
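Roughly, that part looks like the following (a simplified sketch with a toy loss; in my real code `loss_fn` is the network's loss and `params` are the flattened weights):

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

# Toy stand-ins; the real loss_fn/params come from my network.
def loss_fn(params):
    return jnp.sum(jnp.cosh(params))  # convex toy loss, so H is positive definite

params = jnp.arange(1.0, 6.0)

def hvp(v):
    # Hessian-vector product without materialising H (forward-over-reverse).
    return jax.jvp(jax.grad(loss_fn), (params,), (v,))[1]

# eps ~ N(0, I); solving H v = eps with CG gives v = H^{-1} eps.
eps = jax.random.normal(jax.random.PRNGKey(0), params.shape)
v, _ = cg(hvp, eps)
```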
However, now I need to figure out how to multiply the Cholesky of my inverse Hessian by epsilon (presumably still using the conjugate gradient method):
$$v = \mathrm{Chol}(H^{-1}) \cdot \epsilon \Rightarrow \mathrm{Chol}(H)^{T} \cdot v = \epsilon$$
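For concreteness, this is the dense computation I'm trying to reproduce matrix-free (toy sizes only; an explicit $H$ like this is exactly what I can't afford for my real parameter space):

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

# Toy explicit positive-definite "Hessian", just for illustration.
key = jax.random.PRNGKey(0)
A = jax.random.normal(key, (5, 5))
H = A @ A.T + 5.0 * jnp.eye(5)

eps = jax.random.normal(jax.random.PRNGKey(1), (5,))

# Direct form: multiply eps by Chol(H^{-1}).
sample_direct = jnp.linalg.cholesky(jnp.linalg.inv(H)) @ eps

# Rearranged form: solve Chol(H)^T v = eps instead of inverting H.
M = jnp.linalg.cholesky(H)
sample_solved = solve_triangular(M.T, eps, lower=False)

# Both draws have covariance H^{-1}.
```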
My understanding is that `custom_jvp` is for implementing custom derivative rules, not for customising the Hessian calculation / Hessian-vector product, so I suppose I can't use it in this case.

Is it possible to perform this Cholesky-of-inverse-Hessian-vector product efficiently (i.e., without storing the entire square Hessian) in JAX?
Thanks in advance!