Enforce float16 #6700
Hi,
Thanks for the question! Just a quick clarification: `jax_enable_x64` does not enforce 64-bit precision; it only allows 64-bit types to be used.

To answer your question directly: no, there is no way to enforce 16-bit precision for all calculations. However, JAX's type promotion semantics have been designed specifically to make it easier to maintain whatever precision you want, as long as you explicitly set the dtype of the values you create.

As a brief example, here's NumPy's promotion behavior when adding a float and an int:

```python
import numpy as np

x = np.arange(5, dtype='float16')
y = np.ones(5, dtype=int)
# NumPy widens the result so it can represent the integer operand
print((x + y).dtype)
# float64
```

and here's the equivalent in JAX (shown with `bfloat16`; `float16` behaves the same way):

```python
import jax.numpy as jnp

x = jnp.arange(5, dtype='bfloat16')
y = jnp.ones(5, dtype=int)
# JAX keeps the width of the floating-point operand
print((x + y).dtype)
# bfloat16
```
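One way to act on this in practice, sketched under the assumption that your model state lives in a pytree such as a dict: cast the floating-point inputs and parameters to float16 once, up front, and let JAX's promotion rules keep the rest of the computation there. The helper `to_float16` and the toy `params` dict are hypothetical; `jax.tree_util.tree_map` and `jnp.issubdtype` are standard JAX calls. Note that this still does not enforce anything globally: a single value created without an explicit dtype can widen the result again.

```python
import jax
import jax.numpy as jnp


def to_float16(tree):
    """Hypothetical helper: cast every floating-point leaf of a pytree to float16."""

    def cast(leaf):
        leaf = jnp.asarray(leaf)
        # Integer leaves (step counters, indices) are left untouched.
        if jnp.issubdtype(leaf.dtype, jnp.floating):
            return leaf.astype(jnp.float16)
        return leaf

    return jax.tree_util.tree_map(cast, tree)


# Toy parameters created at JAX's default precision (float32 unless x64 is enabled).
params = {"w": jnp.ones(3), "b": jnp.zeros(3), "step": jnp.array(0)}
params = to_float16(params)

# Inputs are created directly in float16.
x = jnp.arange(3, dtype=jnp.float16)

h = x * params["w"] + params["b"]      # float16 op float16 -> float16
h = h * 0.5 + 1                        # Python scalars are weakly typed: stays float16
h = h + jnp.ones(3, dtype=jnp.int32)   # int + float16 -> float16 in JAX
print(h.dtype)  # float16

# One array created without an explicit dtype is enough to widen the result.
widened = h + jnp.ones(3)
print(widened.dtype)  # float32 (float64 if jax_enable_x64 is enabled)
```

The design choice here is to control precision at the points where values are created (inputs and parameters) rather than at every operation, which is exactly what JAX's promotion rules are built to support.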