Can JIT change values by working with a different precision? #12020
I am facing a situation where a function returns different values when JITed vs. run as pure Python (it is difficult to make a minimal working example). While investigating, I noticed that the reason might be related to float32 precision. The following code strikes me as odd:

```python
>>> a = jnp.array(0.006963793188333511, dtype=jnp.float32)
>>> b = jnp.array(0.006963792722672224, dtype=jnp.float32)
>>> a
DeviceArray(0.00696379, dtype=float32)
>>> b
DeviceArray(0.00696379, dtype=float32)
>>> a == b
DeviceArray(False, dtype=bool)
```

Is this behaviour to be expected? (I would have thought that the repr

The main reason I am interested is that the JITed function produces
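A minimal sketch of how to check whether two values really are the same float32, using only Python's standard library rather than JAX: round each literal to float32 with `struct` and compare the raw 32-bit patterns instead of the printed repr. The helpers `to_float32` and `float32_bits` are hypothetical names introduced for illustration, and the sketch assumes that `struct`'s `'f'` format rounds to nearest in the same way as the float64-to-float32 conversion done by `jnp.array(..., dtype=jnp.float32)`.

```python
import struct


def to_float32(x: float) -> float:
    # Hypothetical helper: round a Python float (float64) to the nearest
    # float32, then widen it back so it can be printed as a Python float.
    return struct.unpack("<f", struct.pack("<f", x))[0]


def float32_bits(x: float) -> int:
    # Hypothetical helper: raw IEEE 754 single-precision bit pattern of x.
    return struct.unpack("<I", struct.pack("<f", x))[0]


a = to_float32(0.006963793188333511)
b = to_float32(0.006963792722672224)

# With 6 significant digits the two values print identically...
print(f"{a:.6g}", f"{b:.6g}")  # 0.00696379 0.00696379
# ...but with more digits, or by comparing bits, they are clearly distinct.
print(f"{a:.10g}", f"{b:.10g}")
print(a == b)  # False
# They are neighbouring float32 values: their bit patterns differ by one.
print(float32_bits(a) - float32_bits(b))  # 1
```

In other words, the two literals in the question land on adjacent float32 values (one unit in the last place apart), which a 6-digit repr cannot distinguish.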
Two float values having the same repr does not necessarily mean they contain the same bits. You can adjust the representation using the `set_printoptions` function, which JAX inherits from numpy. For example:

```python
>>> jnp.set_printoptions(precision=10)
>>> a
DeviceArray(0.006963793, dtype=float32)
>>> b
DeviceArray(0.0069637927, dtype=float32)
```

As for why JIT might lead to slightly different floating point outputs, this is covered in the FAQ: https://jax.readthedocs.io/en/latest/faq.html#jit-changes-the-exact-numerics-of-outputs

Hope that helps!
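The FAQ linked above has the details. As a rough illustration of one mechanism behind it: floating-point addition is not associative, so if the compiled program evaluates a sum in a different order than the eager one, the float32 result can change. That JIT reorders this particular sum is an assumption made for illustration; which rewrites the compiler actually performs depends on the program. The sketch below emulates float32 rounding with the standard library's `struct` module instead of running JAX, and the helpers `f32` and `add32` are hypothetical.

```python
import struct


def f32(x: float) -> float:
    # Hypothetical helper: round a Python float to the nearest float32,
    # emulating the rounding that every float32 operation performs.
    return struct.unpack("<f", struct.pack("<f", x))[0]


def add32(x: float, y: float) -> float:
    # Emulated float32 addition: add in float64, then round once to float32.
    return f32(x + y)


x, y, z = f32(1e8), f32(-1e8), f32(1.0)

left = add32(add32(x, y), z)   # (x + y) + z: exact cancellation, then + 1
right = add32(x, add32(y, z))  # x + (y + z): the 1 is lost when y + z rounds

print(left, right)  # 1.0 0.0
```

The magnitudes here are chosen to make the effect obvious; in practice the discrepancy is often only in the last few bits, as with the two values in the question.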