Why does jvp require primals and tangents to have the same dtype? #13223
Answered by jakevdp
mariogeiger asked this question in Q&A
Why is it not possible to compute the following?

import jax
import jax.numpy as jnp
jax.jvp(jnp.sin, (0.1,), (1. + 1j,))  # raises TypeError: primal and tangent dtypes differ
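To see the mismatch concretely, here is a small sketch that inspects the dtypes JAX assigns to each argument of the call above. It assumes JAX's default configuration (64-bit mode disabled), under which the primal 0.1 becomes float32 and the tangent 1 + 1j becomes complex64. jax.jvp compares exactly these dtypes.

```python
import jax.numpy as jnp

# The primal 0.1 is a Python float; JAX converts it to a (weakly typed)
# floating-point array, float32 under the default 32-bit configuration.
primal_dtype = jnp.asarray(0.1).dtype

# The tangent 1 + 1j is a Python complex; JAX converts it to a complex
# array, complex64 under the default 32-bit configuration.
tangent_dtype = jnp.asarray(1.0 + 1.0j).dtype

# jax.jvp does not promote these to a common dtype, so the call fails.
print(primal_dtype, tangent_dtype)
```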
Answered by jakevdp on Nov 13, 2022
Replies: 1 comment
Answer selected by mariogeiger
In general, JAX does not do implicit type promotion in binary operations (e.g. try running jax.lax.add(1, 1.0)), and this convention extends to JVP rules, so jax.jvp will not promote a float primal and a complex tangent to a common dtype. This kind of stricter convention can help catch bugs, so it is useful to enforce it in lower-level APIs. The jax.numpy package, of course, does do automatic type promotion for many operations (e.g. jnp.add(1, 1.0)), mainly because NumPy itself supports this kind of operation and jax.numpy mirrors that API.
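The contrast between jax.numpy's promotion and jax.jvp's strict dtype check can be sketched as follows, assuming JAX's default configuration. The two workarounds are illustrations, not something the answer prescribes. The first promotes the primal to complex explicitly. The second keeps the primal real and relies on the JVP being linear in the tangent, pushing the real and imaginary parts through separately. The variable names t_re and t_im are illustrative only.

```python
import jax
import jax.numpy as jnp

# jax.numpy mirrors NumPy and promotes mixed int/float inputs automatically.
promoted = jnp.add(1, 1.0)  # result has a floating-point dtype

# jax.jvp does not promote: it checks that each tangent's dtype matches its
# primal's dtype and raises TypeError when they differ (float vs complex here).
try:
    jax.jvp(jnp.sin, (0.1,), (1.0 + 1.0j,))
    jvp_rejected = False
except TypeError:
    jvp_rejected = True

# Workaround 1 (explicit promotion): give primal and tangent the same complex dtype.
x = jnp.asarray(0.1, dtype=jnp.complex64)
t = jnp.asarray(1.0 + 1.0j, dtype=jnp.complex64)
y, y_dot = jax.jvp(jnp.sin, (x,), (t,))

# Workaround 2 (linearity in the tangent): keep the primal real, push the real
# and imaginary parts of the tangent 1 + 1j through separately, then recombine.
t_re, t_im = 1.0, 1.0  # real and imaginary parts of the tangent 1 + 1j
_, dot_re = jax.jvp(jnp.sin, (0.1,), (t_re,))
_, dot_im = jax.jvp(jnp.sin, (0.1,), (t_im,))
y_dot_split = dot_re + 1j * dot_im

print(promoted, jvp_rejected, y_dot, y_dot_split)
```

Both workarounds give cos(0.1) * (1 + 1j) for jnp.sin. They are not interchangeable in general. A function that behaves differently on complex inputs (jnp.abs, for example) can produce a different tangent once its primal is promoted to complex.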