
Hi, thanks for the question! The difference is that in JAX, scalars are represented as zero-dimensional arrays, while in NumPy, scalars have their own types.

So in NumPy, np.float32(0) and np.array(0, dtype='float32') are different kinds of objects: the former has type np.float32, while the latter has type np.ndarray.
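A quick way to see this in NumPy is a minimal sketch like the one below (the variable names are just for illustration). It only inspects the types of the two values.

```python
import numpy as np

scalar = np.float32(0)                   # a NumPy scalar object
zero_d = np.array(0, dtype='float32')    # a zero-dimensional array

print(type(scalar) is np.float32)        # True
print(type(zero_d) is np.ndarray)        # True
print(isinstance(scalar, np.ndarray))    # False: scalars are not ndarrays
print(zero_d.shape)                      # ()
```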

In JAX, jnp.float32(0) and jnp.array(0, dtype='float32') are essentially the same kind of object: they're both zero-dimensional arrays of type jax.Array with dtype float32.
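The same check in JAX might look like this rough sketch (assuming a recent JAX where jax.Array is the public array type). Both constructors yield values of the same type, shape and dtype.

```python
import jax
import jax.numpy as jnp

a = jnp.float32(0)                   # looks like a scalar constructor...
b = jnp.array(0, dtype='float32')    # ...but gives the same thing as this

for x in (a, b):
    # Each line prints: True () float32
    print(isinstance(x, jax.Array), x.shape, x.dtype)

print(type(a) is type(b))            # True: no separate scalar type
```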

For jnp.isscalar, this put us in a bit of a pickle, because in a sense there's no way to match NumPy's behavior: JAX doesn't distinguish between scalars and zero-dimensional arrays, while NumPy does! The compromise we came up …
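To make the dilemma concrete, here is a small sketch that uses np.isscalar on the two NumPy values and then converts both to JAX with jnp.asarray. It deliberately makes no claim about what jnp.isscalar itself returns. It only shows that once the values are JAX arrays, nothing in the objects records which of the two NumPy forms they came from.

```python
import numpy as np
import jax.numpy as jnp

np_scalar = np.float32(0)
np_zero_d = np.array(0, dtype='float32')

# NumPy can tell these apart because they are different types.
print(np.isscalar(np_scalar))   # True
print(np.isscalar(np_zero_d))   # False

# After conversion to JAX, both are zero-dimensional float32 arrays
# of the same type, so the scalar/array distinction is gone.
j_scalar = jnp.asarray(np_scalar)
j_zero_d = jnp.asarray(np_zero_d)
print(type(j_scalar) is type(j_zero_d))   # True
print(j_scalar.ndim, j_zero_d.ndim)       # 0 0
```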


@jakevdp
Answer selected by selamw1
Category: Q&A