jax.lax.scan can have less numerical precision than a for loop #27655
Anyone have insight into why these two functions would behave differently under jax.lax.scan?

It's very interesting: for both functions I can roll them out using a for loop or using jax.lax.scan, and then compute the "one-step residuals," i.e. carry_{t+1} - step(carry_t).

For 3 of the 4 versions (the 2 _step functions, crossed with either loop or scan), I get all of these residuals to be zero, as they should be. But if I scan _step_add, I get numerical error growing with the sequence length.

Any ideas what's going on?

Here's a repro: https://colab.research.google.com/drive/18cz0q6RNAAs9o1wn58alarbFF0rKp-5Z?usp=sharing
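A minimal sketch of the residual check described above. The recurrence in `_step_add` here is a hypothetical stand-in, not the notebook's actual code; the real step functions are in the linked repro:

```python
import jax
import jax.numpy as jnp

def _step_add(carry, _):
    # Hypothetical recurrence: x_{t+1} = x_t + eps * x_t
    new = carry + 1e-4 * carry
    return new, new

def rollout_scan(x0, n):
    _, ys = jax.lax.scan(_step_add, x0, xs=None, length=n)
    return ys

def rollout_loop(x0, n):
    carry, ys = x0, []
    for _ in range(n):
        carry, y = _step_add(carry, None)
        ys.append(y)
    return jnp.stack(ys)

def one_step_residuals(ys, x0):
    # carry_{t+1} - step(carry_t), with the step recomputed eagerly
    prev = jnp.concatenate([jnp.array([x0]), ys[:-1]])
    stepped = jax.vmap(lambda c: _step_add(c, None)[0])(prev)
    return ys - stepped

x0, n = jnp.float32(1.0), 10_000
print(jnp.abs(one_step_residuals(rollout_loop(x0, n), x0)).max())  # expect exactly 0
print(jnp.abs(one_step_residuals(rollout_scan(x0, n), x0)).max())  # may be nonzero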
Replies: 1 comment
Interesting comparison! One important piece here is that you're not only comparing scan vs. for loop; you're actually comparing native Python floating-point math vs. XLA floating-point math.

Regarding the residuals: keep in mind that floating-point math is only an approximation of real math, and in general different ways of doing the "same" floating-point operation will accumulate rounding error differently. The rounding error in this case is roughly what you'd expect: about 1 part in 10^16; but when you raise (1 + eps) to the power of 10,000, you get a very large baseline, so the absolute error looks large too. If you look at relative rather than absolute error in the output (i.e. the absolute difference divided by the magnitude of the value), it should be on the order of machine epsilon.
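To see the absolute-vs-relative-error distinction the reply draws, here is a small sketch; eps, n, and the multiplicative recurrence are assumed values for illustration, not taken from the notebook:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # double precision, for parity with Python floats

eps, n = 1e-2, 10_000  # assumed values, chosen so the baseline grows very large

# Python-float rollout: each multiply rounds in native double precision.
acc = 1.0
for _ in range(n):
    acc *= (1.0 + eps)

# The same product rolled out with jax.lax.scan, compiled by XLA.
def mul(carry, _):
    return carry * (1.0 + eps), None

scanned, _ = jax.lax.scan(mul, jnp.float64(1.0), xs=None, length=n)

abs_err = abs(float(scanned) - acc)
rel_err = abs_err / abs(acc)
print(f"baseline ~ {acc:.3e}")      # ~1.6e43, so even a tiny relative error...
print(f"abs err  = {abs_err:.3e}")  # ...can look enormous in absolute terms,
print(f"rel err  = {rel_err:.3e}")  # while staying near 1 part in 10^16 relatively
```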