Replies: 1 comment
There's not any easy way to do this in general, but it may be possible for special cases; see the previous discussion here: #3801
To make this question more concrete, let's consider an example neural network in Flax.
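The model definition that originally accompanied this sentence did not survive extraction; below is a minimal sketch of what such a model might look like, assuming a small two-layer MLP (the layer sizes and the names `model`, `batch`, `label`, and `params` are placeholders, not the original code):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=16)(x)
        x = nn.relu(x)
        return nn.Dense(features=1)(x)

model = MLP()
batch = jnp.ones((4, 8))   # placeholder input batch
label = jnp.ones((4, 1))   # placeholder targets
params = model.init(jax.random.PRNGKey(0), batch)
```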
Suppose we are training the neural network on some data. For the purpose of this question, `batch` and `label` are considered constant, so let's define a helper function `f` and use it to calculate the loss. We can calculate the gradient with `jax.grad`, and similarly the Hessian with `jax.hessian`.
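The code blocks for these steps were likewise lost; a sketch of what they might look like, reusing the assumed names above and an illustrative squared-error loss:

```python
# Helper that closes over the constant batch and label, so the loss is a
# function of the parameters only.
def f(params):
    preds = model.apply(params, batch)
    return jnp.mean((preds - label) ** 2)

loss = f(params)              # scalar loss
grads = jax.grad(f)(params)   # same pytree structure as params
hess = jax.hessian(f)(params) # nested pytree of second derivatives
```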
The question is how to calculate the diagonal of the Hessian efficiently. In particular, if we concatenate all parameters into a huge vector of size $N$, then the Hessian will be a matrix of size $N \times N$, whose diagonal is a vector of size $N$; this vector is the desired "Hessian diagonal". Essentially, for each scalar component in the parameter vector, I want to fix all other parameters and calculate the second-order derivative of the loss function with respect to that particular component. In addition, I want the Hessian diagonal to have the same shape as the parameters and gradient.
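In other words, the result should be a pytree that mirrors the parameters; a small illustration with the assumed names from the sketches above:

```python
import jax.tree_util as jtu

print(jtu.tree_map(jnp.shape, params))  # e.g. {'params': {'Dense_0': {'bias': (16,), 'kernel': (8, 16)}, ...}}
print(jtu.tree_map(jnp.shape, grads))   # identical structure and leaf shapes
# The desired hessian_diag should have this same structure and these same shapes.
```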
Technically I can form the full Hessian and extract its diagonal, but that is going to have quadratic time complexity with respect to the number of parameters. Is there a more efficient solution in linear time?
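For reference, the naive quadratic-cost route could be sketched as follows, flattening the parameters with `jax.flatten_util.ravel_pytree` (one possible way to do the concatenation described above):

```python
from jax.flatten_util import ravel_pytree

flat_params, unravel = ravel_pytree(params)       # all parameters concatenated into one vector of size N

def f_flat(flat):
    return f(unravel(flat))                       # the same loss, viewed as a function of the flat vector

full_hessian = jax.hessian(f_flat)(flat_params)   # dense (N, N) matrix: quadratic in N, time and memory
hessian_diag = unravel(jnp.diag(full_hessian))    # diagonal, reshaped back into the params pytree
```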