
See: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision

JAX enforces 32-bit calculations by default, so the different output comes from performing the calculation in 32-bit rather than 64-bit precision.

You would get the same output from Sq in NumPy, as a float32, if you instead set:

x = np.float32(4.2)
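
To see the effect without JAX at all, here is a minimal NumPy-only sketch. The function `square` is a hypothetical stand-in for the computation in the question (the original Sq is not shown here), so treat it as an assumption rather than the actual code:

```python
import numpy as np


def square(x):
    # Hypothetical stand-in for the function in the question (an assumption).
    return x * x


# The same literal stored at two different precisions.
x64 = np.float64(4.2)
x32 = np.float32(4.2)

y64 = square(x64)
y32 = square(x32)

# The float32 result is rounded to far fewer significant digits,
# so the two printed values differ after the first several digits.
print(y64.dtype, repr(float(y64)))
print(y32.dtype, repr(float(y32)))
print("difference:", float(y64) - float(y32))
```

The linked documentation section also describes how to opt in to 64-bit values in JAX through the `jax_enable_x64` configuration flag, if you need results that match NumPy's float64 default.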

Answer selected by EngineerKhan