Well, the first thing to realize is that in Python, type annotations are not enforced at runtime: in normal code they neither check nor convert the values passed in or returned. So you can simplify your example function to this:

def f(v):
  return 1.0 * v
print(f(10.0))
# 10.0

This is a Python function which, when called as above, accepts a Python float, multiplies it by a Python float, and returns a Python float. It's just Python: JAX doesn't enter the picture at all.
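To make the point about annotations concrete, here is a minimal sketch in plain Python. The `int` annotations are deliberately wrong and are chosen only for illustration; they are not taken from the original question.

```python
def f(v: int) -> int:
    # The annotations claim int, but nothing checks or converts at call time.
    return 1.0 * v

result = f(10.0)
print(type(result).__name__, result)  # float 10.0

# The annotations are only stored as metadata on the function object.
print(f.__annotations__)
```

The call succeeds and returns a float even though both annotations say `int`, because the interpreter records annotations without acting on them.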

Now this is different:

import jax.numpy as jnp
print(f(jnp.array([10.0])))
# [10.]

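This is a Python function which accepts a JAX array and multiplies it by a Python scalar. Because the scalar is on the left in 1.0 * v, Python first tries float.__mul__, which does not know about JAX arrays, and then falls back to the array's __rmul__ method (the reflected counterpart of __mul__). JAX defines these methods on its array objects (DeviceArray in older releases, jax.Array in current ones) so that the result is again a JAX array.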
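The dispatch described here can be sketched without JAX. `ToyArray` below is a hypothetical stand-in written for this illustration only; it is not how JAX implements its arrays, but it shows why the same untyped `f` returns whatever type the operand's multiplication methods produce.

```python
class ToyArray:
    """Hypothetical stand-in for an array type; not JAX's implementation."""

    def __init__(self, values):
        self.values = list(values)

    def __mul__(self, other):
        # Handles `ToyArray * scalar`.
        if isinstance(other, (int, float)):
            return ToyArray(x * other for x in self.values)
        return NotImplemented

    # Handles `scalar * ToyArray`: float.__mul__ returns NotImplemented,
    # so Python falls back to the right operand's __rmul__.
    __rmul__ = __mul__

    def __repr__(self):
        return f"ToyArray({self.values})"


def f(v):
    return 1.0 * v


out = f(ToyArray([10.0]))
print(type(out).__name__, out)  # ToyArray ToyArray([10.0])
```

Nothing in `f` mentions `ToyArray`; the return type is decided entirely by the operand's own methods, which is the same mechanism that makes `f` return a JAX array when it is given one.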

Where implicit conversion ma…

Answer selected by jecampagne