I have an expensive loss function that returns a lot of intermediate results as auxiliary data. I want to optimize it with LBFGS (with line search). However, to cache the calculated gradients I need to use `opt.value_and_grad_from_state`, but it does not allow threading auxiliary values. For a simple case, I have adapted the example into an MWE.

Is there an option like `has_aux`, as is available in `jax.value_and_grad`?
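A sketch of what such an MWE might look like, following the documented `optax.lbfgs` / `optax.value_and_grad_from_state` pattern (the loss `f` and its aux dict are illustrative stand-ins for the real, expensive computation):

```python
import jax.numpy as jnp
import optax

# Illustrative loss: the aux dict stands in for the expensive intermediates.
def f(x):
    residuals = x ** 2
    return jnp.sum(residuals), {'residuals': residuals}

# value_and_grad_from_state expects a scalar-valued function,
# so the aux has to be dropped up front.
value_fn = lambda x: f(x)[0]

opt = optax.lbfgs()
params = jnp.array([1.0, 2.0, 3.0])
state = opt.init(params)
value_and_grad = optax.value_and_grad_from_state(value_fn)

for _ in range(10):
    # Reuses the value/grad cached by the line search when possible,
    # but there is no way to get the aux of f out of this call.
    value, grad = value_and_grad(params, state=state)
    updates, state = opt.update(
        grad, state, params, value=value, grad=grad, value_fn=value_fn)
    params = optax.apply_updates(params, updates)
```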
Replies: 2 comments

- I think it makes sense to add this option. Would you be interested in contributing a PR? If you're interested in a quick workaround, optax has a very simple implementation of this functionality which you could directly reproduce in your code (Line 287 in 9b682ab).
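  As an illustration of that suggestion, a rough, untested sketch of such a reproduction extended with aux support. The name `value_and_grad_from_state_with_aux`, the `aux_placeholder` argument, and the cache-hit behavior are assumptions, not optax API: the line-search state caches only the scalar value and the gradient, so on a cache hit the sketch returns a caller-supplied placeholder instead of the true aux.

  ```python
  import jax
  import jax.numpy as jnp
  import optax.tree_utils as otu

  def value_and_grad_from_state_with_aux(value_fn, aux_placeholder):
      """Hypothetical has_aux variant of optax.value_and_grad_from_state.

      value_fn must return (value, aux). The line-search state caches only
      value and grad, so on a cache hit aux_placeholder (a pytree with the
      same structure and dtypes as the real aux) is returned instead.
      """
      def _value_and_grad(params, *fn_args, state, **fn_kwargs):
          value = otu.tree_get(state, 'value')
          grad = otu.tree_get(state, 'grad')
          if value is None or grad is None:
              raise ValueError(
                  'The optimizer state does not cache value and grad; use '
                  'an optimizer with a line search, e.g. optax.lbfgs.')

          def use_cache(params, fn_args, fn_kwargs):
              return value, grad, aux_placeholder

          def recompute(params, fn_args, fn_kwargs):
              (value, aux), grad = jax.value_and_grad(
                  value_fn, has_aux=True)(params, *fn_args, **fn_kwargs)
              return value, grad, aux

          # Mirror optax: trust the cache only when the stored value is
          # finite (the initial state is seeded so the first call computes).
          value, grad, aux = jax.lax.cond(
              jnp.isfinite(value), use_cache, recompute,
              params, fn_args, fn_kwargs)
          return (value, aux), grad

      return _value_and_grad
  ```

  Note that the `value_fn` forwarded to `opt.update` for the line search would still have to be scalar-valued, e.g. `lambda p: value_fn(p)[0]`.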
- (second comment failed to load)