I observed some interesting discrepancies in the behavior of `isscalar` and `ndim` between NumPy and JAX:

```python
import jax.numpy as jnp
import numpy as np
import jax
print(jax.__version__)
print("\n**NumPy Behavior**")
print(np.isscalar(1.1), np.ndim(1.1)) # True 0
print(np.isscalar(False), np.ndim(False)) # True 0
print(np.isscalar('jax'), np.ndim('jax')) # True 0
print(np.isscalar(np.float32(1)), np.ndim(np.float32(1))) # True 0
print(np.isscalar(np.complex64(1)), np.ndim(np.complex64(1))) # True 0
print(np.isscalar([1.1]), np.ndim([1.1])) # False 1
print(np.isscalar(np.array(1.1)), np.ndim(np.array(1.1))) # False 0
print("\n**JAX Behavior**")
# Consistent with NumPy
print(jnp.isscalar(1.1), jnp.ndim(1.1)) # True 0
print(jnp.isscalar(False), jnp.ndim(False)) # True 0
print(jnp.isscalar('jax'), jnp.ndim('jax')) # True 0
# Inconsistency for JAX-specific numeric types
print(jnp.isscalar(jnp.float32(1)), jnp.ndim(jnp.float32(1))) # False 0
print(jnp.isscalar(jnp.complex64(1)), jnp.ndim(jnp.complex64(1))) # False 0
print(jnp.isscalar([1.1]), jnp.ndim([1.1])) # False 1
print(jnp.isscalar(jnp.array(1.1)), jnp.ndim(jnp.array(1.1))) # True 0
# Inconsistency for NumPy-created arrays
print(jnp.isscalar(jnp.array(1)), jnp.ndim(jnp.array(1))) # True 0
print(jnp.isscalar(np.array(1)), jnp.ndim(np.array(1))) # False 0
```

Output:
Key Observations:
Questions for Discussion:
Thank you!
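In the results above the `ndim` values agree between NumPy and JAX in every case, while `isscalar` does not. A check built on `np.ndim` therefore gives the same answer for Python scalars, NumPy scalar types, and zero-dimensional arrays. A minimal sketch; `is_zero_dim` is a hypothetical helper for illustration, not part of NumPy or JAX:

```python
import numpy as np

def is_zero_dim(x) -> bool:
    """Hypothetical helper: True for anything NumPy treats as zero-dimensional.

    np.ndim returns x.ndim when the attribute exists (NumPy scalars, NumPy
    arrays and JAX arrays all have it) and otherwise falls back to
    np.asarray(x).ndim, so scalar types and 0-d arrays are treated alike.
    """
    return np.ndim(x) == 0

# NumPy inputs from the question: isscalar and is_zero_dim disagree only on the 0-d array.
for value in [1.1, False, np.float32(1), np.array(1.1), [1.1]]:
    print(type(value).__name__, np.isscalar(value), is_zero_dim(value))
```

This deliberately counts 0-d arrays as scalars, which matches how JAX represents scalars. Because `np.ndim` reads the `.ndim` attribute, the helper should also report `True` for JAX values such as `jnp.float32(1)` or `jnp.array(1.1)`. Note that it returns `True` for a string like `'jax'` as well, since `np.ndim('jax')` is 0.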
Hi - thanks for the question! The difference is that in JAX, scalars are represented by zero-dimensional arrays, while in NumPy, scalars have their own type.

So in NumPy, `np.float32(0)` and `np.array(0, dtype='float32')` are different objects: the former has type `np.float32`, the latter has type `np.ndarray`.

In JAX, `jnp.float32(0)` and `jnp.array(0, dtype='float32')` are essentially the same object: they're both zero-dimensional arrays of type `jax.Array`, with dtype float32.

For `jnp.isscalar`, this put us in a bit of a pickle, because in a sense there's no way to match NumPy's behavior: JAX doesn't distinguish between scalars and zero-dimensional arrays, while NumPy does! The compromise we came up …

What do you think?

Thank you @jakevdp for the detailed answer. I think one needs to keep these points and differences in mind while using …
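The distinction jakevdp describes, where NumPy has separate scalar types but JAX represents scalars as zero-dimensional arrays, can be checked directly. A minimal sketch, assuming a recent JAX release in which `jax.Array` is the public array type; the variable names are illustrative:

```python
import numpy as np
import jax
import jax.numpy as jnp

# NumPy: a scalar type and a 0-d array are different kinds of object.
np_scalar = np.float32(0)
np_array = np.array(0, dtype='float32')
print(type(np_scalar), type(np_array))  # numpy.float32 vs numpy.ndarray
print(np_scalar.ndim, np_array.ndim)    # both are zero-dimensional

# JAX: both constructors produce a zero-dimensional jax.Array with dtype float32.
jnp_scalar = jnp.float32(0)
jnp_array = jnp.array(0, dtype='float32')
print(isinstance(jnp_scalar, jax.Array), isinstance(jnp_array, jax.Array))  # True True
print(jnp_scalar.shape, jnp_array.shape, jnp_scalar.dtype, jnp_array.dtype)
```

Since nothing in the JAX objects records whether they were created "as a scalar", `jnp.isscalar` has no type-level signal to reproduce NumPy's split between the two cases.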