Hi,

```python
import jax
jax.config.update('jax_platform_name', 'cpu')
import jax.numpy as jnp
import numpy as np
# %%
mse = lambda y, y_hat: ((y-y_hat)**2).mean()
key = jax.random.PRNGKey(0)
n = 9
a = np.array(jax.random.uniform(key, shape=[n,n]), dtype=np.float32)
b = np.array(jax.random.uniform(key, shape=[n]), dtype=np.float32)
c_np = np.einsum('ij,j', a, b)
c_np_t = np.einsum('ji,j', a.T, b)
c_jnp = jnp.einsum('ij,j', a, b)
c_jnp_t = jnp.einsum('ji,j', a.T, b)
print(f"{jax.__version__=}")
print(f"{mse(c_np, c_np_t)=}")
print(f"{mse(c_np, c_jnp)=}")
print(f"{mse(c_jnp, c_jnp_t)=}")
print(f"{mse(c_np_t, c_jnp_t)=}")
```

This outputs:
Thanks for the clear question!

The difference here is that for 32-bit inputs, JAX performs its operations in 32-bit, while NumPy performs its operations in 64-bit and then casts the output back to 32-bit. JAX operates this way because it's designed for use on accelerators, where 64-bit operations can be significantly slower than their 32-bit counterparts. NumPy is designed only for CPU, so it does not share this concern.

If you want to replicate the NumPy results with JAX, you need to enable x64 mode and then do the casting explicitly. For example:

```python
import jax
jax.config.update('jax_platform_name', 'cpu')
# Enable 64-bit types; set this at startup, before any JAX arrays are created.
jax.config.update('jax_enable_x64', True)
import jax.numpy as jnp
import numpy as np
# %%
mse = lambda y, y_hat: ((y-y_hat)**2).mean()
key = jax.random.PRNGKey(0)
n = 9
a = np.array(jax.random.uniform(key, shape=[n,n]), dtype=np.float32)
b = np.array(jax.random.uniform(key, shape=[n]), dtype=np.float32)
# Upcast the float32 inputs to float64 explicitly.
a64 = a.astype('float64')
b64 = b.astype('float64')
c_np = np.einsum('ij,j', a, b)
c_np_t = np.einsum('ji,j', a.T, b)
# Compute in float64, then cast the result back to float32, mirroring NumPy's behavior for float32 inputs.
c_jnp = jnp.einsum('ij,j', a64, b64).astype('float32')
c_jnp_t = jnp.einsum('ji,j', a64.T, b64).astype('float32')
print(f"{jax.__version__=}")
print(f"{mse(c_np, c_np_t)=}")
print(f"{mse(c_np, c_jnp)=}")
print(f"{mse(c_jnp, c_jnp_t)=}")
print(f"{mse(c_np_t, c_jnp_t)=}")
```
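The effect of accumulating in 32-bit versus accumulating in 64-bit and casting back can be seen without JAX or einsum at all. Below is a minimal NumPy-only sketch, not how NumPy or XLA actually implement einsum: it adds many tiny float32 values to 1.0, once with a float32 accumulator and once with a float64 accumulator whose final result is cast to float32. The specific values (`1e-8`, 1000 terms) are arbitrary choices for illustration.

```python
import numpy as np

# Illustrative values (arbitrary): a tiny term and how many times to add it.
small = np.float32(1e-8)  # below half a float32 ulp at 1.0 (about 6e-8)
n_terms = 1000

# Accumulate in float32: every addition rounds back to float32,
# so each tiny term is rounded away and the sum stays at 1.0.
acc32 = np.float32(1.0)
for _ in range(n_terms):
    acc32 = acc32 + small

# Accumulate in float64, then cast the final result to float32 once.
acc64 = np.float64(1.0)
for _ in range(n_terms):
    acc64 = acc64 + np.float64(small)
acc64_as_f32 = np.float32(acc64)

print("float32 accumulation:            ", acc32)
print("float64 accumulation, cast to f32:", acc64_as_f32)
```

Both results are float32, but only the float64 path retains the small contributions (ending near 1.00001). Differences between the float32 einsum results in the question come from the same kind of rounding, though usually much smaller in magnitude.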