In general, JAX does not do implicit type promotion in binary operations (try running, e.g., jax.lax.add(1, 1.0)), and this convention extends to JVP rules. A stricter convention like this can help catch bugs, so it's useful to enforce it in lower-level APIs.
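As an illustration of what this means when writing your own JVP rule, here is a minimal sketch using jax.custom_jvp. The function `cube` is hypothetical; the point is that the rule builds its integer constant directly in the primal's dtype with lax.convert_element_type, so every lax op sees matching dtypes and the tangent comes out with the same dtype as the primal.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.custom_jvp
def cube(x):
    # Primal: x * x * x using lax primitives (operands share x's dtype).
    return lax.mul(x, lax.mul(x, x))


@cube.defjvp
def cube_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    # Create the constant 3 in x's dtype rather than relying on promotion.
    three = lax.convert_element_type(3, x.dtype)
    # d/dx x**3 = 3 * x**2, scaled by the incoming tangent.
    tangent_out = lax.mul(lax.mul(three, lax.mul(x, x)), x_dot)
    return cube(x), tangent_out


primal_out, tangent_out = jax.jvp(cube, (jnp.float32(2.0),), (jnp.float32(1.0),))
print(primal_out, tangent_out, tangent_out.dtype)
```

Writing `lax.mul(3, ...)` directly would mix an integer with a float inside the rule; the explicit cast keeps the rule consistent with the no-implicit-promotion convention.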

The jax.numpy package, of course, does do automatic type promotion for many operations (e.g. jnp.add(1, 1.0)), mainly because NumPy itself supports this kind of operation and jax.numpy mirrors its API.

Answer selected by mariogeiger
Category
Q&A
2 participants