Issue with dtype float and GPU/TPU #7194
kyrieheard asked this question in Q&A
Hi, thanks for the question and the clear reproduction! The first warning you're getting is normal on a CPU backend: JAX prefers a GPU or TPU when one is available, and it warns you when it falls back to the CPU unless you explicitly request the CPU (see #6805).

The second warning is caused by requesting float64 types while running outside X64 mode. Outside X64 mode, JAX truncates requested float64 values to float32 and emits that warning. See 🔪 JAX - The Sharp Bits 🔪: Double (64bit) precision for more information.

With that in mind, I find that the expected output is produced if you enable X64 mode in your script by putting these lines at the top:

```python
import jax

# Older JAX releases also accepted: from jax.config import config; config.update(...)
jax.config.update("jax_enable_x64", True)
```
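To make the effect of the flag concrete, here is a minimal sketch. It does not reproduce the article's code (which isn't shown here), and it assumes a JAX version where `jax.config.update` is available; the arrays `x` and `y` are purely illustrative.

```python
import jax

# Turn on 64-bit mode first, before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

# With x64 enabled, an explicit float64 request is honored (no truncation warning).
x = jnp.arange(3, dtype=jnp.float64)
print(x.dtype)  # float64

# float64 can resolve a 1e-10 increment on 1.0; float32 (eps ~1.2e-7) cannot.
y = jnp.array(1.0, dtype=jnp.float64) + 1e-10
print(bool(y > 1.0))  # True
```

Without the `jax_enable_x64` line, the same script emits the truncation warning, `x.dtype` comes out as float32, and `1.0 + 1e-10` rounds back to exactly 1.0, so the comparison prints False.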
I'm using code from an article for a project, but I can't get it to work.
The output is supposed to be:
But I'm getting this:
Here is the code: