Two float values that have the same repr do not necessarily contain the same bits. You can control how values are printed using the `set_printoptions` function, which JAX inherits from NumPy. With a higher precision, the difference between these two values becomes visible:

```python
>>> jnp.set_printoptions(precision=10)
>>> a
DeviceArray(0.006963793, dtype=float32)
>>> b
DeviceArray(0.0069637927, dtype=float32)
```
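If you want to check bit-level equality directly instead of relying on printed output, one option is to compare the raw bit patterns. This is a minimal sketch using only the standard-library `struct` module, so it runs without JAX or NumPy. It assumes `a` is the float32 value nearest 0.006963793 and takes `b` to be the float32 one step below it, which is an illustrative stand-in, not necessarily the exact `b` from the question. The helper names are hypothetical.

```python
import struct


def to_float32(x: float) -> float:
    """Round a Python float to the nearest float32 (round-trip through 4 bytes)."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def float32_bits(x: float) -> int:
    """Return the raw 32-bit pattern of x stored as float32."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def float32_from_bits(bits: int) -> float:
    """Reinterpret a 32-bit integer pattern as a float32 value."""
    return struct.unpack("<f", struct.pack("<I", bits))[0]


# Hypothetical stand-ins for the two values in the question.
a = to_float32(0.006963793)
# For a positive float, subtracting 1 from the bit pattern gives the
# next representable float32 below it (one unit in the last place).
b = float32_from_bits(float32_bits(a) - 1)

# With 8 fractional digits (NumPy's default print precision) they look identical.
print(f"{a:.8f}", f"{b:.8f}")  # 0.00696379 0.00696379
# With 10 fractional digits the difference becomes visible.
print(f"{a:.10f}", f"{b:.10f}")
# The underlying bit patterns differ.
print(hex(float32_bits(a)), hex(float32_bits(b)))
```

For NumPy or JAX float32 arrays, `np.asarray(x).view(np.uint32)` should give the same kind of bit-level view.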

As for why JIT can produce slightly different floating-point outputs, see this FAQ entry: https://jax.readthedocs.io/en/latest/faq.html#jit-changes-the-exact-numerics-of-outputs
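The FAQ has the details. As a small illustration of one typical cause (this specific example is mine, not taken from the FAQ), floating-point addition is not associative, so a compiler that reorders or fuses operations can change the last bit of a result. Plain Python floats are used so the sketch runs without JAX. `ulp_distance` is a hypothetical helper that assumes both inputs are finite and positive.

```python
import math
import struct


def ulp_distance(x: float, y: float) -> int:
    """Count representable doubles between x and y (assumes finite, positive inputs)."""
    xi = struct.unpack("<q", struct.pack("<d", x))[0]
    yi = struct.unpack("<q", struct.pack("<d", y))[0]
    return abs(xi - yi)


# Same three numbers, summed in two different orders.
left = (0.1 + 0.2) + 0.3
right = 0.1 + (0.2 + 0.3)

print(left, right)                               # 0.6000000000000001 0.6
print(left == right)                             # False
print(ulp_distance(left, right))                 # 1
print(math.isclose(left, right, rel_tol=1e-12))  # True
```

When comparing jitted and non-jitted outputs, it is usually more robust to compare with a tolerance (`math.isclose` here, or `jnp.allclose` for arrays) than to expect exact equality.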

Hope that helps!

Answer selected by onnoeberhard