
This is normal floating-point rounding error. Floating-point addition is not associative, so summing the same numbers in a different order can give slightly different results. For single-precision (float32) arithmetic, differences on the order of 1E-8 are typical.
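As a minimal illustration of the non-associativity, here is a NumPy float32 sketch. The magnitudes are deliberately exaggerated so that the rounding is exact and easy to see. With values of similar size, the discrepancy shows up only in the last few bits of the result.

```python
import numpy as np

a = np.float32(1e8)
b = np.float32(-1e8)
c = np.float32(1.0)

# a + b is exactly 0, so adding c afterwards gives 1.0.
left = (a + b) + c
# Near 1e8 the float32 spacing is 8, so b + c rounds back to -1e8
# and the final sum is 0.0.
right = a + (b + c)

print(left, right)
```

Both expressions are mathematically equal to 1, but the grouping decides which intermediate results get rounded.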

Both operations lower to essentially the same computation (a `dot_general` primitive call). In the second case, however, the compiler likely chooses to accumulate the products in a different order. You can reproduce a similar difference without `vmap` by reversing the entries before computing a matrix-vector product. For example:

```python
import numpy as np
import jax.numpy as jnp
from jax import device_put

stdfloat = jnp.float32

n = 6
v = 8

np.random.seed(0)
# Random (n, v) matrix in float32, placed on the default device.
vec = device_put(np.random.uniform(low=-1, high=1, size=(n, v)).astype(stdfloat))
```
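The snippet above stops after building the matrix. Here is a self-contained sketch of how the comparison could continue. It assumes an extra random vector `x` of length `v`, which is not part of the original code. The sketch computes the same matrix-vector product three ways: directly, with the summation index reversed, and row by row with `jax.vmap`. All three are mathematically identical. Any differences come only from the order of accumulation. The exact values, or whether they are nonzero at all, depend on the backend and the XLA version, so the sketch only checks that the results agree to float32 tolerance.

```python
import numpy as np
import jax
import jax.numpy as jnp

stdfloat = jnp.float32
n, v = 6, 8

np.random.seed(0)
# Same (n, v) float32 matrix as above.
vec = jnp.asarray(np.random.uniform(low=-1, high=1, size=(n, v)).astype(stdfloat))
# Hypothetical extra input: a random float32 vector of length v.
x = jnp.asarray(np.random.uniform(low=-1, high=1, size=(v,)).astype(stdfloat))

# Plain matrix-vector product.
y_ref = jnp.dot(vec, x)

# Same product with the summation index reversed: each output is still
# sum_j vec[i, j] * x[j], but the terms may be accumulated in another order.
y_rev = jnp.dot(vec[:, ::-1], x[::-1])

# Same product expressed per row with vmap (mapping over rows of vec).
y_vmap = jax.vmap(jnp.dot, in_axes=(0, None))(vec, x)

print("max |ref - rev| :", float(jnp.max(jnp.abs(y_ref - y_rev))))
print("max |ref - vmap|:", float(jnp.max(jnp.abs(y_ref - y_vmap))))
```

Any nonzero differences printed here should be tiny, around the float32 rounding level. For this reason, float32 results should be compared with a tolerance (for example `np.allclose`) rather than with exact equality.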
