Replies: 3 comments 6 replies
-
Generally, optax stays at a higher level and leaves an optimization like this (e.g., an opportunity to reuse gradients) to the XLA compiler, thanks to jax.jit. Do you have a specific application in mind? There may be ways to express the computation reuse so that it's easier for the compiler to pick up on an optimization like that.
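As a concrete illustration of the point above: jax.value_and_grad expresses the loss and its gradient as a single computation, so under jax.jit XLA can share the forward pass between them instead of recomputing it. The loss here is a hypothetical toy quadratic, just a stand-in for a real objective:

```python
import jax
import jax.numpy as jnp

def loss(params):
    # Toy quadratic loss; a placeholder for a real objective.
    return jnp.sum((params - 1.0) ** 2)

# value_and_grad shares the forward pass between the loss value and
# the gradient; under jit, XLA can fuse and deduplicate that work.
value_and_grad = jax.jit(jax.value_and_grad(loss))

v, g = value_and_grad(jnp.zeros(3))
# v is 3.0 and g is [-2., -2., -2.] for this toy loss.
```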
-
I think you should be able to create a joint optimization objective (likely by summing the losses for all cases) and run a single global optax optimization; vmap plus compiling the whole step (or the whole run) should work well.
For a parallel line search you should be able to write your own optax linesearch, or we can explore adding it to optax.
-
Just to complete a trivial example, we will use the Branin function:
f(x, y) = a(y - bx² + cx - r)² + s(1 - t)cos(x) + s
with the standard constants a = 1, b = 5.1/(4π²), c = 5/π, r = 6, s = 10, t = 1/(8π). Then we know the 3 local (and global) minima are at (-π, 12.275), (π, 2.275), and (9.42478, 2.475), and the typical domain is x ∈ [-5, 10], y ∈ [0, 15]. But we could also just:
(still trying to understand those two values that are not local optima...)
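For reference, a plain-Python sketch of the Branin function with the standard constants, evaluated at the three minima quoted above; each evaluation should land near the known optimal value of about 0.397887:

```python
import math

def branin(x, y):
    # Branin function with the standard constants.
    a = 1.0
    b = 5.1 / (4.0 * math.pi ** 2)
    c = 5.0 / math.pi
    r = 6.0
    s = 10.0
    t = 1.0 / (8.0 * math.pi)
    return a * (y - b * x ** 2 + c * x - r) ** 2 + s * (1.0 - t) * math.cos(x) + s

# The three global minima in the usual domain x in [-5, 10], y in [0, 15].
minima = [(-math.pi, 12.275), (math.pi, 2.275), (9.42478, 2.475)]
values = [branin(x, y) for x, y in minima]
# Each value is approximately 0.397887.
```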
-
Is there benefit to using vmap to start multiple optimizations in parallel? That is, would there be computational "savings" with gradient reuse?