Hello all, I am fairly new to optax, and am trying to call
Edit: Evaluating
L-BFGS works robustly on this problem, much better than SGD or Adam, so I'd rather not switch to another algorithm. Since it's the tangent that overflows, not the primal value, I don't quite know what to do. Is it possible to add a stopping criterion, or to modify lbfgs somehow to prevent this? Thanks a lot in advance! 🙏🙏🙏
Oh, this is such an interesting failure mode, I'd love to take a look! Perhaps this is related to #1189, but it's just a guess. Thanks for the repro, I'll try it out. In the meantime, if you want to dig into it yourself as well, the JAX debugging flags might be useful here: https://docs.jax.dev/en/latest/debugging/flags.html
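The two flags on that page that look most relevant are `jax_debug_nans` and `jax_debug_infs`. With them enabled, JAX raises a `FloatingPointError` as soon as an operation outputs a NaN or an inf, and this also applies to operations in the tangent (JVP) computation. The toy function below, `jnp.sqrt` at 0, is chosen only for illustration: its primal is finite but its tangent `0.5 / sqrt(x)` is infinite, which mirrors the situation in the question.

```python
import jax
import jax.numpy as jnp

# Raise FloatingPointError as soon as any operation produces a NaN or an inf,
# including operations that only run in the tangent (JVP) pass.
jax.config.update("jax_debug_nans", True)
jax.config.update("jax_debug_infs", True)

caught = None
try:
    # sqrt(0) = 0 is fine, but its derivative 0.5 / sqrt(0) overflows to inf.
    jax.jvp(jnp.sqrt, (jnp.array(0.0),), (jnp.array(1.0),))
except FloatingPointError as err:
    caught = err
    print("tangent overflow caught:", err)
finally:
    # The checks slow every computation down, so switch them off again.
    jax.config.update("jax_debug_nans", False)
    jax.config.update("jax_debug_infs", False)
```

Enabling the flags around the failing `jax.jvp` call in the repro should point at the first operation whose output is non-finite.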
I was able to run your repro. Unfortunately, the NaN debugging utilities in JAX are still a work in progress. I think in your case the linesearch is what causes the NaNs in the gradients. I tried replacing the default zoom linesearch in lbfgs with:
One possible temporary solution would be to implement your own linesearch in a gradient-safe way. I'll add investigating the linesearch's gradient stability to my TODO list.
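A gradient-safe linesearch could be as simple as backtracking by a constant factor, as suggested in the reply below. The sketch below is a standalone illustration and not an optax API. `backtracking_step_size` is a hypothetical helper that starts at step 1 and multiplies by 0.7 until the objective strictly decreases. It checks only a plain decrease, with no Armijo condition. Because the returned step is always `0.7**k` for an integer `k`, the search involves no divisions and the step has a zero derivative with respect to its inputs. The quadratic `loss` and the plain gradient direction are placeholders. Wiring this into L-BFGS, for example by using the direction produced by `optax.scale_by_lbfgs`, is not shown and would need checking against the optax API.

```python
import jax
import jax.numpy as jnp


def backtracking_step_size(fn, params, direction, decrease=0.7, max_steps=30):
    """Hypothetical helper: shrink the step by `decrease` until fn strictly decreases.

    The step is always decrease**k for an integer k, so the search uses no
    divisions and the step size has a zero derivative w.r.t. every input.
    """
    value0 = fn(params)

    def keep_shrinking(carry):
        step, k = carry
        new_value = fn(params + step * direction)
        # NaN compares False, so a NaN trial value also counts as "no decrease".
        return jnp.logical_not(new_value < value0) & (k < max_steps)

    def shrink(carry):
        step, k = carry
        return step * decrease, k + 1

    init = (jnp.float32(1.0), jnp.int32(0))
    step, _ = jax.lax.while_loop(keep_shrinking, shrink, init)
    return step


def loss(x, theta):
    # Placeholder objective with its minimum at x = theta.
    return jnp.sum((x - theta) ** 2)


theta = jnp.array(3.0)
x0 = jnp.zeros(2)
grad = jax.grad(loss)(x0, theta)
step = backtracking_step_size(lambda x: loss(x, theta), x0, -grad)
x1 = x0 - step * grad
print(step, loss(x0, theta), loss(x1, theta))

# The step size is piecewise constant in theta, so its tangent is exactly zero.
_, step_dot = jax.jvp(
    lambda th: backtracking_step_size(lambda x: loss(x, th), x0, -grad),
    (theta,),
    (jnp.array(1.0),),
)
print(step_dot)
```

Here the full step of 1 overshoots to the mirror point with the same loss, so one shrink is needed and the accepted step is 0.7.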
The local minimization procedures in the linesearch might be to blame, given that they perform a lot of divisions. For numerical stability, you could try implementing a simple scale-by-0.7-until-fn-value-is-lower-than-last-value linesearch. Another possibility would be to use implicit differentiation, like here or here.
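Here is a sketch of the implicit-differentiation idea. Instead of differentiating through the optimizer's iterations, linesearch included, you differentiate the optimality condition ∇ₓf(x*, θ) = 0. This gives ẋ* = -H⁻¹ (∂θ∇ₓf) θ̇, where H is the Hessian in x at the solution. The code hand-rolls this with `jax.custom_jvp`. Several pieces are assumptions for a small toy problem. The objective `inner_loss` is chosen so that x* = log θ, hence dx*/dθ = 1/θ. A short Newton loop stands in for L-BFGS, and the Hessian solve is dense. In the real problem, the existing L-BFGS loop could play the role of `solve`, since its iterations are never differentiated. For large parameter vectors, a matrix-free solver such as `jax.scipy.sparse.linalg.cg` on Hessian-vector products would replace `jnp.linalg.solve`.

```python
import jax
import jax.numpy as jnp


def inner_loss(x, theta):
    # Toy strongly convex objective: grad_x = exp(x) - theta = 0  =>  x* = log(theta).
    return jnp.sum(jnp.exp(x)) - jnp.dot(theta, x)


grad_x = jax.grad(inner_loss, argnums=0)
hess_x = jax.hessian(inner_loss, argnums=0)


def solve(theta):
    # Stand-in solver (Newton's method). Any optimizer could go here,
    # because its iterations are never differentiated.
    def newton_step(_, x):
        return x - jnp.linalg.solve(hess_x(x, theta), grad_x(x, theta))

    return jax.lax.fori_loop(0, 20, newton_step, jnp.zeros_like(theta))


@jax.custom_jvp
def argmin_x(theta):
    return solve(theta)


@argmin_x.defjvp
def argmin_x_jvp(primals, tangents):
    (theta,), (theta_dot,) = primals, tangents
    x_star = solve(theta)
    # Differentiate the optimality condition grad_x(x*, theta) = 0:
    #   H @ x_dot + (d grad_x / d theta) @ theta_dot = 0
    _, b = jax.jvp(lambda th: grad_x(x_star, th), (theta,), (theta_dot,))
    x_dot = -jnp.linalg.solve(hess_x(x_star, theta), b)
    return x_star, x_dot


theta = jnp.array([0.5, 2.0])
x_star, x_dot = jax.jvp(argmin_x, (theta,), (jnp.ones_like(theta),))
print(x_star)  # close to log(theta)
print(x_dot)   # close to 1 / theta
```

The tangent now comes from a single linear solve at the solution, so the divisions inside the linesearch never enter the derivative. This requires the Hessian to be well conditioned at x*.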