
Thanks for the clear question!

The difference here is that for 32-bit inputs, JAX performs its operations in 32-bit precision, while NumPy performs them in 64-bit precision and then casts the output back to 32-bit. JAX works this way because it is designed for accelerators such as GPUs and TPUs, where 64-bit operations can be significantly slower than their 32-bit counterparts. NumPy is designed only for CPU, so it does not share this concern.
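To make the difference concrete, here is a small NumPy-only sketch comparing the two strategies on a value where float32 rounding matters. The helper names are hypothetical, and the scalar arithmetic is a stand-in, not the operation from the original question:

```python
import numpy as np

def compute_in_float32(a, b):
    """Hypothetical helper: keep every intermediate in float32 (the strategy described for JAX)."""
    a32, b32 = np.float32(a), np.float32(b)
    return (a32 + b32) - a32

def compute_in_float64_then_cast(a, b):
    """Hypothetical helper: start from float32 inputs, compute in float64, cast back to float32."""
    a64, b64 = np.float64(np.float32(a)), np.float64(np.float32(b))
    return np.float32((a64 + b64) - a64)

x, y = 1e8, 1.0
print(compute_in_float32(x, y))            # 0.0: 1e8 + 1 rounds back to 1e8 in float32
print(compute_in_float64_then_cast(x, y))  # 1.0: float64 represents 1e8 + 1 exactly
```

Near 1e8 the spacing between adjacent float32 values is 8, so the `+ 1` is lost to rounding when the whole computation stays in 32-bit, while the 64-bit intermediate keeps it. Both functions still return float32.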

If you want to replicate the NumPy results with JAX, you need to enable x64 mode and then do this casting explicitly. For example:

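```python
import jax
# Set these at startup, before any JAX arrays are created:
# run on CPU and enable 64-bit (x64) types
jax.config.update('jax_platform_name', 'cpu')
jax.config.update('jax_enable_x64', True)
import jax.numpy as jnp
import numpy as np
```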
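The example above stops after the imports. As a hedged sketch of the "do this casting explicitly" step (not the answer's actual code; the helper `upcast_compute_downcast` and the stand-in function `f` are hypothetical), one way to mimic the upcast, compute, downcast pattern in JAX is:

```python
import jax
# x64 mode must be enabled at startup, before any JAX arrays are created
jax.config.update('jax_enable_x64', True)
import jax.numpy as jnp
import numpy as np

def upcast_compute_downcast(f, x):
    """Hypothetical helper: run f in float64, then cast the result back to x's dtype."""
    x = jnp.asarray(x)
    return f(x.astype(jnp.float64)).astype(x.dtype)

def f(v):
    # Stand-in operation whose float32 result is sensitive to rounding
    return (v + 1) - v

x = np.array([1e8], dtype=np.float32)

default_result = f(jnp.asarray(x))                  # computed entirely in float32
numpy_style_result = upcast_compute_downcast(f, x)  # computed in float64, returned as float32

print(default_result, default_result.dtype)          # value 0.0: the +1 is lost to float32 rounding
print(numpy_style_result, numpy_style_result.dtype)  # value 1.0: float64 keeps the +1
```

Both results are float32, but only the second one goes through 64-bit intermediates. It is worth checking the output against NumPy for the specific operation you care about.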

Answer selected by bevector
Category: Q&A
2 participants