Replies: 1 comment
-
Inlining the expressions by hand:

```python
hess_vp_1 = jvp(lambda x: vjp(func, alpha)[1](x)[0], (alpha,), (test,))[1]
hess_vp_2 = jvp(grad(func), (alpha,), (test,))[1]
```

But `grad(f)(z) == vjp(f, z)[1](1.)[0]`. In words: the gradient at `z` is the VJP taken at `z` applied to a cotangent of `1.`, so it is the linearization point that varies, not the cotangent. In `hess_vp_1` the lambda keeps the point fixed at `alpha` and instead varies the cotangent `x`, which is not the same function as `grad(func)`.

(This is only a guess, based on a quick glance.)
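Applying that identity to the first expression (a sketch, assuming `func` is a scalar-valued function of `alpha`; `hvp_inlined` is just an illustrative name, and `func`, `alpha`, `test` are as in the snippet above), the inlined forward-over-reverse HVP would look like:

```python
from jax import jvp, vjp

# Differentiate with respect to the point z at which the VJP is taken,
# keeping the output cotangent fixed at 1., instead of differentiating
# with respect to the cotangent as in the original hess_vp_1.
hvp_inlined = jvp(lambda z: vjp(func, z)[1](1.)[0], (alpha,), (test,))[1]
```

For a scalar `func`, `lambda z: vjp(func, z)[1](1.)[0]` is exactly `grad(func)`, so this matches `hess_vp_2`.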
-
Hi,
I've implemented some custom code to perform trust-region conjugate gradient optimization using JAX/LAX primitives with some success; however, I've based it off analytic gradients and Hessians for my specific problem. I'd love to start porting this to a general interface that more closely mimics the `minimize_*` API within JAX, but I have a question regarding the efficient HVP example in the tutorial.

Namely, I'd like to have a function `value_and_grad_and_hvp` that behaves similarly to `value_and_grad`, but also returns the Hessian-vector product function `hvp`. The reason for this is that, while the tutorial provides an excellent example of how to compute the `hvp` function, it would require multiple calls to `grad`, which internally re-calls `value_and_grad`. This likely isn't a huge bottleneck, but for my particular application I am trying to squeeze as much as I can out of the implementation, due to the massive number of optimizations that need to be performed.

I've tried to wrap my own function that computes the value, gradient, and HVP directly from `jvp` and `vjp` calls, but I'm having trouble figuring out 1) why the dimensionality is not lining up and 2) why the HVP values are incorrect (likely a consequence of 1). I've coded up a toy example to illustrate what I'm trying to accomplish and what the contrast is (a simplified sketch follows below). Any help or suggestions would be greatly appreciated.
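Roughly, the setup and the contrast look like this (a simplified sketch rather than my actual toy example; `func`, `x`, and `v` are placeholder names):

```python
import jax
import jax.numpy as jnp
from jax import grad, jvp

def func(x):
    # toy scalar objective
    return jnp.sum(jnp.sin(x) ** 2)

# Tutorial-style HVP (forward-over-reverse); every call re-enters grad.
def hvp(f, x, v):
    return jvp(grad(f), (x,), (v,))[1]

# The interface I'm after: value, gradient, and an hvp closure from one
# wrapper. The naive version below still calls grad separately, which is
# the redundancy I'd like to avoid by building it from jvp/vjp directly.
def value_and_grad_and_hvp(f, x):
    value, g = jax.value_and_grad(f)(x)
    return value, g, lambda v: hvp(f, x, v)

x = jnp.arange(3.0)
v = jnp.ones_like(x)
value, g, hvp_fn = value_and_grad_and_hvp(func, x)
print(value, g, hvp_fn(v))
```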
thanks!