Can we enable a type promotion warning, just like the rank promotion warning? I want to check whether an array is promoted to float64 when using x64.
Answered by
jakevdp
Oct 4, 2022
Replies: 1 comment
There's no mechanism for warning on dtype promotion, but it is possible to enable strict promotion mode in which type promotion between different dtypes will raise an error. For example:
import jax
import jax.numpy as jnp
jax.config.update('jax_numpy_dtype_promotion', 'strict')
jax.numpy.add(jnp.arange(5, dtype=float), jnp.arange(5, dtype=int))
# TypePromotionError: Input dtypes ('float32', 'int32') have no available implicit
# dtype promotion path when jax_numpy_dtype_promotion=strict. Try explicitly casting
# inputs to the desired output type, or set jax_numpy_dtype_promotion=standard.
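If changing the global config is too broad, strict mode can also be scoped to a single check. The following is a minimal sketch, not an official recipe: it assumes the installed JAX version provides the `jax.numpy_dtype_promotion` context manager and that `TypePromotionError` is a subclass of `ValueError`. The helper name `promotes_under_strict` is hypothetical.

```python
import jax
import jax.numpy as jnp


def promotes_under_strict(fn, *args):
    """Return True if fn(*args) needs an implicit dtype promotion.

    Hypothetical helper: runs fn under strict promotion mode and reports
    whether a promotion error was raised.
    """
    # Assumption: jax.numpy_dtype_promotion is available as a context manager.
    with jax.numpy_dtype_promotion('strict'):
        try:
            fn(*args)
        except ValueError:
            # Assumption: TypePromotionError subclasses ValueError.
            return True
    return False


x = jnp.arange(5, dtype=jnp.float32)
y = jnp.arange(5, dtype=jnp.int32)

mixed = promotes_under_strict(jnp.add, x, y)  # float32 + int32 arrays
same = promotes_under_strict(jnp.add, x, x)   # float32 + float32 arrays
print(mixed, same)
```

Because the context manager only applies inside the `with` block, the rest of the program keeps the standard promotion rules.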
Answer selected by
imoneoi