-
Can anyone please explain this behaviour to me? I am still very new to JAX, and this is my first attempt with `vmap`.
And, for example, get:

Is the difference in precision/values my error, or am I missing something with `vmap` perhaps? I note that when I drop the value of `v` to 1, the discrepancy disappears, e.g.:
Thanks in advance for your help!
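[Editor's note: the question's original snippets did not survive. A hypothetical sketch of the kind of comparison being described — batching a matrix-vector product over the columns of `vec` with `vmap` and comparing one batched column against the same product computed directly; all names here are assumptions, not the poster's code:]

```python
import numpy as np
import jax.numpy as jnp
from jax import vmap, device_put

# Hypothetical reconstruction, not the original code: build random single
# precision inputs, then compare a vmap-batched matrix-vector product
# against the same product computed directly on one column.
n, v = 6, 8
np.random.seed(0)
vec = device_put(np.random.uniform(-1, 1, size=(n, v)).astype(jnp.float32))
mat = device_put(np.random.uniform(-1, 1, size=(n, n)).astype(jnp.float32))

# Map the dot product over the columns of vec (axis 1 in, axis 1 out).
batched = vmap(lambda col: jnp.dot(mat, col), in_axes=1, out_axes=1)(vec)
direct = jnp.dot(mat, vec[:, 0])
print(batched[:, 0] - direct)  # entries of order 1e-8 may appear
```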
Replies: 2 comments 1 reply
-
This is normal floating point precision error: when you add floating point numbers in different orders, the results often have differences of order 1e-8 for single precision arithmetic. The two operations lower to essentially the same thing (a `dot_general` primitive call), but in the second case the compiler likely chooses to compute things in a different order. You can reproduce a similar difference without `vmap` by reversing the entries before computing a matrix-vector product; for example:

```python
import numpy as np
import jax.numpy as jnp
from jax import device_put

stdfloat = jnp.float32
n = 6
v = 8
np.random.seed(0)
vec = device_put(np.random.uniform(low=-1, high=1, size=(n, v)).astype(stdfloat))
mat = device_put(np.random.uniform(low=-1, high=1, size=(n, n)).astype(stdfloat))

def mydot(mat, vec):
    return jnp.dot(mat, vec, precision=('float32', 'float32'))

out1 = mydot(mat, vec[:, 0])
out2 = mydot(mat[:, ::-1], vec[::-1, 0])
print(out1 - out2)
```

```
[ 0.0000000e+00  5.9604645e-08  0.0000000e+00 -5.9604645e-08
  0.0000000e+00  2.9802322e-08]
```

There's not really any way to "fix" this, as it's basically a fact of life when working with floating point arithmetic in any framework.
-
These errors could compound with rounding:
outputs:
And
outputs:
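[Editor's note: the snippets and outputs for this comment were lost. A hypothetical sketch of how a ~1e-7 single-precision discrepancy can compound under rounding — the values and the boundary at 1.5 are illustrative assumptions, not the original code:]

```python
import jax.numpy as jnp

# Hypothetical illustration: two float32 values that differ by about one
# ulp straddle the rounding boundary at 1.5, so jnp.round turns a ~1e-7
# discrepancy into a difference of a full unit.
a = jnp.float32(1.5)
b = a - jnp.float32(1.2e-7)  # roughly one float32 ulp below 1.5
print(a - b)                  # tiny difference, order 1e-7
print(jnp.round(a), jnp.round(b))  # 2.0 vs 1.0: the error compounded
```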