First of all, I have become a JAX fanboy and have started using it for everything. Thanks for this! For an application I need high-precision square roots of relatively large floats (on the order of millions). I have noticed that NumPy's sqrt is much more precise than JAX's. Is there a way to control the precision of sqrt, similar to e.g. the precision argument of jnp.dot?
JAX executes code in float32 precision by default. If you want float64 precision (similar to NumPy), you can enable it using the `--jax_enable_x64` flag; see 🔪 JAX - The Sharp Bits 🔪: Double (64bit) precision for more information.

I had expected some kind of ultra-fast sqrt implementation on the GPU for fast inverse square root computations, and didn't realize it was just float32 precision. Thanks!
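A minimal sketch of the suggested fix. It assumes a JAX version where 64-bit mode can also be switched on from code with `jax.config.update("jax_enable_x64", True)` instead of the command-line flag. The sketch compares a float32 and a float64 square root against NumPy. The input `x` is a hypothetical stand-in for "floats on the order of millions".

```python
import jax

# Enable 64-bit mode. This must happen at startup, before any JAX arrays
# are created; it is the in-code counterpart of --jax_enable_x64.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import numpy as np

x = 1_234_567.0  # hypothetical input "on the order of millions"

# float64 square root: same precision as NumPy's default.
sqrt64 = jnp.sqrt(jnp.asarray(x, dtype=jnp.float64))

# float32 square root: what JAX computes by default when x64 is disabled.
sqrt32 = jnp.sqrt(jnp.asarray(x, dtype=jnp.float32))

# NumPy reference computed in float64.
reference = np.sqrt(np.float64(x))

err64 = abs(float(sqrt64) - reference)
err32 = abs(float(sqrt32) - reference)

print("float64:", float(sqrt64), "abs error vs NumPy:", err64)
print("float32:", float(sqrt32), "abs error vs NumPy:", err32)
```

With x64 enabled, the float64 result should agree with NumPy to within float64 rounding. The float32 result is off by roughly half a float32 ulp near 1111, which is on the order of 1e-4. If x64 were not enabled, requesting `dtype=jnp.float64` would silently fall back to float32 (with a warning). That is why the flag or config switch is needed rather than just a dtype argument.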