There's no mechanism for warning on dtype promotion, but you can enable the strict promotion mode, in which implicit type promotion between different dtypes raises an error instead. For example:

```python
import jax
import jax.numpy as jnp
jax.config.update('jax_numpy_dtype_promotion', 'strict')
jax.numpy.add(jnp.arange(5, dtype=float), jnp.arange(5, dtype=int))
# TypePromotionError: Input dtypes ('float32', 'int32') have no available implicit
#   dtype promotion path when jax_numpy_dtype_promotion=strict. Try explicitly casting
#   inputs to the desired output type, or set jax_numpy_dtype_promotion=standard.
```
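If you only want strict checking in part of a program, the setting can also be scoped. The sketch below assumes a JAX release that exposes the `jax.numpy_dtype_promotion` context manager and that `TypePromotionError` subclasses `ValueError` (both true in recent versions, but worth checking against yours). It shows the error being raised inside the strict block and the fix the error message suggests: casting explicitly so the intended output dtype is unambiguous.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(5, dtype=jnp.float32)
y = jnp.arange(5, dtype=jnp.int32)

strict_error = None
# Strict promotion applies only inside this block; the previous setting is restored on exit.
with jax.numpy_dtype_promotion('strict'):
    try:
        jnp.add(x, y)  # mixed float32/int32 inputs have no implicit promotion path here
    except ValueError as err:  # assumption: TypePromotionError is a ValueError subclass
        strict_error = err

    # Casting one operand explicitly removes the need for implicit promotion.
    z = jnp.add(x, y.astype(jnp.float32))

print(type(strict_error).__name__)
print(z.dtype)
```

Because the context manager restores the previous mode on exit, this is a convenient way to audit one section of code for accidental promotions without switching the whole program to strict mode.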

Answer selected by imoneoi