I have a project that is built on top of another large library. I want to add an extension at the end of the calculation that needs precise float64 numerics, but only for a small part of it. Essentially, I have to compute the difference between a large vector and a very small one, take the norm of that difference (which seems to be very unstable in float32), and then take the derivative of the result. I know that I can enable float64 with …
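One plausible source of the float32 instability described above is absorption: when the large vector's entries are many orders of magnitude bigger than the small ones, the small contribution falls below float32's spacing and disappears entirely. The sketch below illustrates this with made-up magnitudes (`1e8` and `1.0` are assumptions, not values from the project); the float64 comparison uses NumPy so it does not depend on JAX's X64 flag.

```python
import numpy as np
import jax.numpy as jnp

# Hypothetical magnitudes: a "large" vector and a "very small" one.
large32 = jnp.full(3, 1e8, dtype=jnp.float32)
small32 = jnp.full(3, 1.0, dtype=jnp.float32)

# Near 1e8, consecutive float32 values are 8 apart, so 1e8 - 1 rounds back to 1e8.
diff32 = large32 - small32
print(bool(jnp.all(diff32 == large32)))  # True: the small vector was lost

# The same subtraction in float64 keeps the small contribution.
large64 = np.full(3, 1e8, dtype=np.float64)
small64 = np.full(3, 1.0, dtype=np.float64)
diff64 = large64 - small64
print(bool(np.all(diff64 == large64)))  # False: 99999999.0 is representable
```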
Unfortunately, that's not really possible.
`jax_enable_x64` is a global setting for the Python runtime. But with `jax_enable_x64=True`, there's nothing to stop you from doing float32 computations: just make sure to set the dtypes of your arrays explicitly. We've done a lot of work to ensure that for 32-bit inputs you'll get 32-bit computations and 32-bit outputs, even when X64 is enabled.
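A minimal sketch of that approach, assuming the large library's work can be stood in for by the hypothetical `main_computation` and `precise_tail` functions below: X64 is switched on once, the bulk of the pipeline stays in float32 because its inputs are created with explicit `float32` dtypes, and only the difference-and-norm step is upcast to float64 with `astype`. Taking the gradient still returns float32 for the float32 input, since the cotangent is cast back to the input's dtype.

```python
import jax
import jax.numpy as jnp

# Global switch for the whole Python process; set it before creating any arrays.
jax.config.update("jax_enable_x64", True)


def main_computation(x):
    # Hypothetical stand-in for the large library's float32 pipeline.
    # The Python scalar 1e4 is weakly typed, so the result stays float32.
    return jnp.sin(x) * 1e4


def precise_tail(large, small):
    # Hypothetical sensitive step: upcast only here, then subtract and take the norm.
    diff = large.astype(jnp.float64) - small.astype(jnp.float64)
    return jnp.linalg.norm(diff)


def loss(x, small):
    large = main_computation(x)  # float32
    return precise_tail(large, small)  # float64 scalar


# Explicit float32 dtypes keep the main computation in 32-bit despite X64.
x = jnp.linspace(0.0, 1.0, 5, dtype=jnp.float32)
small = jnp.full(5, 1e-3, dtype=jnp.float32)

value, grad = jax.value_and_grad(loss)(x, small)
print(value.dtype, grad.dtype)
```

The design choice here is to keep the precision boundary at a single `astype` call, so the rest of the library does not need to know that X64 is enabled; arrays created without an explicit dtype, however, will default to 64-bit once the flag is on.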