Thanks for the question! Just a quick clarification: jax_enable_x64 does not enforce 64-bit precision; it only allows 64-bit types to be used. To answer your question directly: no, there is no way to enforce 16-bit precision for all calculations. However, JAX's type promotion semantics were designed specifically to make it easier to maintain whatever precision you want, by explicitly setting the dtype of the values you create.
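Here is a minimal sketch of maintaining float16 by hand, assuming JAX's default configuration (jax_enable_x64 off). It checks which results stay in float16 and shows the common way precision silently widens: combining with a value created without an explicit dtype. The to_float16 helper is a hypothetical name for illustration, not a JAX API.

```python
import jax
import jax.numpy as jnp

# With jax_enable_x64 off (the default), 64-bit dtypes are canonicalized to 32-bit.
print(jax.dtypes.canonicalize_dtype(jnp.float64))  # float32 unless x64 is enabled

# Values created with an explicit float16 dtype...
x = jnp.arange(5, dtype=jnp.float16)
i = jnp.arange(5, dtype=jnp.int32)

# ...stay float16 when combined with integer arrays or (weakly typed) Python scalars.
print((x + i).dtype)        # float16
print((x * 2.5 + 1).dtype)  # float16
print(jnp.sin(x).dtype)     # float16

# A value created without a dtype defaults to float32 and widens the result.
w = jnp.ones(5)             # float32 by default
print((x + w).dtype)        # float32

# Hypothetical helper: cast every floating-point leaf of a pytree to float16,
# leaving non-floating leaves (e.g. integer counters) untouched.
def to_float16(tree):
    def cast(a):
        a = jnp.asarray(a)
        if jnp.issubdtype(a.dtype, jnp.floating):
            return a.astype(jnp.float16)
        return a
    return jax.tree_util.tree_map(cast, tree)

params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2), "step": jnp.array(0)}
half = to_float16(params)
print({k: v.dtype for k, v in half.items()})
```

The idea in this sketch is to cast once at the boundary (inputs and parameters) and then let JAX's promotion rules keep most downstream results in float16, rather than relying on a global switch.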

As a brief example, here's NumPy's promotion behavior when adding a float16 array and an integer array:

import numpy as np
x = np.arange(5, dtype='float16')
y = np.ones(5, dtype=int)
print((x + y).dtype)
# float64
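NumPy's widening here follows np.promote_types, which returns the smallest dtype to which both inputs can be safely cast. A minimal sketch querying that rule for float16 paired with several integer widths, without building any arrays:

```python
import numpy as np

# For each integer dtype, ask NumPy which dtype a float16 operand is promoted to.
for int_dtype in ["int8", "int16", "int32", "int64"]:
    result = np.promote_types("float16", int_dtype)
    print(f"float16 + {int_dtype} -> {result}")
# float16 + int8 -> float16
# float16 + int16 -> float32
# float16 + int32 -> float64
# float16 + int64 -> float64
```

Only int8 fits safely in float16, so any wider integer operand pulls the result up to float32 or float64.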

And here's the equivalent in JAX:

import jax.numpy as jnp
x = jnp.arange(5, dtype='float16')
y = jnp.ones(5, dtype=int)
print((x + y).dtype)
# float16

Answer selected by astanziola