jax.numpy and implicit conversion? #10625
Hello, I'm asking how jax.numpy behaves when a function's arguments are type-annotated, and whether the result is converted to the annotated return type. Here is a snippet:

```python
import jax.numpy as jnp

print(1.0*jnp.array([10.]))

def test(v: jnp.array)->jnp.array:
    return 1.0*v

print(test(jnp.array([10.])))
print(test(10.))
```

This leads to …

The last print tends to indicate that passing a … Is this documented somewhere? As a consequence, should the code add an explicit conversion of the result of the computation in `test` to a `jnp.array`? Thanks!
Well, the first thing to realize is that in Python, type annotations are purely decorative and don't affect runtime values in normal code. So you can simplify your example function to this:

```python
def f(v):
    return 1.0 * v

print(f(10.0))
# 10.0
```

This is a Python function that accepts a Python float, multiplies it by a Python float, and returns a Python float. It's just Python: JAX doesn't enter the picture at all.

Now this is different:

```python
import jax.numpy as jnp

print(f(jnp.array([10.0])))
# [10.]
```

This is a Python function that accepts a JAX array and multiplies it by a Python scalar. Such multiplication is defined via the `__mul__` method of JAX `DeviceArray` objects (strictly, its reflected counterpart `__rmul__`, since the scalar is on the left) to return a `DeviceArray`.

Where implicit conversion may enter the picture is if your function is wrapped in a transform like `jit`:

```python
from jax import jit

@jit
def f(v):
    return 1.0 * v

print(repr(f(10.0)))
# DeviceArray(10., dtype=float32, weak_type=True)
```

On input to a JIT-compiled function, a Python scalar will be implicitly converted to a JAX device array, essentially using a routine similar to …

Does that answer your question?
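On the follow-up question about explicit conversion, here is a minimal sketch, assuming a recent JAX release (where the array type is `jax.Array` and `repr` prints `Array(...)` rather than the `DeviceArray(...)` shown above). It repeats the `jit` check and adds a hypothetical helper, `scale_as_array`, that calls `jnp.asarray` on its input so the result is a JAX array whether or not the function is JIT-compiled.

```python
import jax
import jax.numpy as jnp


@jax.jit
def f(v):
    return 1.0 * v


# Under jit, the Python float argument is converted to a JAX array on input,
# so the result is a weakly typed JAX array rather than a Python float.
jitted = f(10.0)
print(repr(jitted))


def scale_as_array(v):
    # Hypothetical helper: convert explicitly with jnp.asarray so the result
    # is a JAX array even when the function is called without jit.
    return 1.0 * jnp.asarray(v)


converted = scale_as_array(10.0)
print(isinstance(converted, jax.Array), converted.dtype)
```

Converting the input is only one option; wrapping the return value in `jnp.asarray` would equally guarantee an array result.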