Hi! The code below shows how the result of a function changes depending on the backend it is compiled for:

```python
import jax

seed = 1234567890
x = jax.random.uniform(jax.random.PRNGKey(seed + 10), (2, 2))
w = jax.random.uniform(jax.random.PRNGKey(seed + 12), (2, 20))

def func(x, w):
    y = (x + 1.).dot(w)
    return y.sum()

print(func(x, w))                          # eager (not jitted)
print(jax.jit(func, backend='gpu')(x, w))  # jitted for GPU
print(jax.jit(func, backend='cpu')(x, w))  # jitted for CPU
# 64.80866
# 64.76463
# 64.80866
```

Is this a bug, or something I should expect? Interestingly, the results agree if one takes …
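A minimal sketch, not from the original thread, for checking the effect without touching any global setting: it requests full float32 precision for a single matmul through the `precision` argument of `jnp.dot` and compares it with the default precision against a float64 NumPy reference. The names `func_default`, `func_highest` and `reference` are illustrative. On CPU both variants should print nearly the same value; a visible gap would be expected only on a backend whose default float32 matmul precision is reduced.

```python
import jax
import jax.numpy as jnp
import numpy as np

seed = 1234567890
x = jax.random.uniform(jax.random.PRNGKey(seed + 10), (2, 2))
w = jax.random.uniform(jax.random.PRNGKey(seed + 12), (2, 20))

def func_default(x, w):
    # Uses the backend's default matmul precision for float32 inputs.
    return jnp.dot(x + 1.0, w).sum()

def func_highest(x, w):
    # Requests full float32 precision for this one matmul only.
    return jnp.dot(x + 1.0, w, precision=jax.lax.Precision.HIGHEST).sum()

# Float64 reference computed on the host with NumPy.
x64 = np.asarray(x, dtype=np.float64)
w64 = np.asarray(w, dtype=np.float64)
reference = float(((x64 + 1.0) @ w64).sum())

for name, f in [("default", func_default), ("highest", func_highest)]:
    value = float(jax.jit(f)(x, w))
    print(f"{name}: {value:.5f} (abs. error vs float64: {abs(value - reference):.2e})")
```

Per-operation `precision` is useful when only a few matmuls are sensitive, since the rest of the program keeps the faster default.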
Answered by mrkwjc, Feb 19, 2024
Problem solved. For better accuracy, a higher matmul precision must be set:

```python
jax.config.update('jax_default_matmul_precision', 'high')  # or its alias 'bfloat16_3x'
```
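The fix above changes the precision for the whole program. A minimal sketch, assuming a JAX version that provides the `jax.default_matmul_precision` context manager, of limiting the change to one block of code instead. The value `'float32'` is used here to request full float32 precision (the thread itself used `'high'`); `baseline` and `precise` are illustrative names.

```python
import jax
import jax.numpy as jnp

seed = 1234567890
x = jax.random.uniform(jax.random.PRNGKey(seed + 10), (2, 2))
w = jax.random.uniform(jax.random.PRNGKey(seed + 12), (2, 20))

def func(x, w):
    y = (x + 1.).dot(w)
    return y.sum()

# Backend default precision, outside any precision scope.
baseline = func(x, w)

# Raise the matmul precision only inside this block instead of
# changing the global jax_default_matmul_precision setting.
with jax.default_matmul_precision("float32"):
    precise = jax.jit(func)(x, w)

print(baseline, precise)
```

Because the setting is part of JAX's tracing context, a function traced inside the `with` block should be compiled with the requested precision, while code outside the block keeps the default.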
Answer selected by mrkwjc