Make GPU-JAX work when CUDA is not available #14208
carlosgmartin asked this question in Q&A
-
Have you tried setting the platform? Maybe that avoids the folder check:

import jax
jax.config.update("jax_platform_name", "cpu")

You would have to do this at the very tippy-top of your code, so that no other JAX computation happens before it.
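A minimal sketch of applying that suggestion only on nodes where the CUDA folder is missing, so the same script can run on both node types. The helper names (`forced_platform`, `configure_jax`) and the default path `/usr/local/cuda` are hypothetical; substitute whatever folder the error actually complains about. Whether `jax_platform_name` really skips the folder check has not been verified here.

```python
import os
from typing import Optional

# Hypothetical location: the actual CUDA folder named in the error is not shown here.
DEFAULT_CUDA_DIR = "/usr/local/cuda"


def forced_platform(cuda_dir: str = DEFAULT_CUDA_DIR) -> Optional[str]:
    """Return "cpu" when the CUDA folder is missing, else None (keep JAX's default)."""
    return None if os.path.isdir(cuda_dir) else "cpu"


def configure_jax(cuda_dir: str = DEFAULT_CUDA_DIR) -> None:
    """Force the CPU platform on nodes without the CUDA folder.

    Call this at the very top of the program, before any JAX computation runs.
    """
    # Imported inside the function so this module can be loaded without JAX,
    # and so the platform is set before any JAX computation happens.
    import jax

    platform = forced_platform(cuda_dir)
    if platform is not None:
        jax.config.update("jax_platform_name", platform)
```

Usage: call `configure_jax()` as the first statement of the program, then use `jax` normally. On GPU nodes the default platform selection is left untouched.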
-
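You could set the JAX_PLATFORMS environment variable to cpu, either in the shell before launching Python or in a script:

import os
os.environ['JAX_PLATFORMS'] = 'cpu'  # set before importing jax
import jax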
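A sketch that applies the environment-variable approach only where it is needed, so one script can run on both GPU and non-GPU nodes. Detecting a GPU by looking for `nvidia-smi` on `PATH` is an assumption (a heuristic, not something JAX itself checks), and the function names `gpu_likely_present` and `set_jax_platforms` are hypothetical.

```python
import os
import shutil
from typing import MutableMapping, Optional


def gpu_likely_present() -> bool:
    """Heuristic only: an NVIDIA driver install usually puts nvidia-smi on PATH."""
    return shutil.which("nvidia-smi") is not None


def set_jax_platforms(
    env: Optional[MutableMapping[str, str]] = None,
    has_gpu: Optional[bool] = None,
) -> str:
    """Set JAX_PLATFORMS=cpu on nodes without a GPU; must run before `import jax`.

    An existing JAX_PLATFORMS value (e.g. exported by the job script) is kept.
    Returns the resulting value, or "" if the variable is left unset.
    """
    if env is None:
        env = os.environ
    if has_gpu is None:
        has_gpu = gpu_likely_present()
    if not has_gpu:
        # setdefault: never overwrite a value the user already chose.
        env.setdefault("JAX_PLATFORMS", "cpu")
    return env.get("JAX_PLATFORMS", "")
```

Usage: call `set_jax_platforms()` and only then `import jax`. Using `setdefault` means an explicit JAX_PLATFORMS exported in a Slurm job script still takes precedence over the heuristic.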
-
I have a Slurm cluster with both GPU and non-GPU nodes.
CPU-JAX, obtained via

python3 -m pip install jax

works on both types of nodes but does not use the GPUs on the GPU nodes.

GPU-JAX, obtained via

python3 -m pip install "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

works on the GPU nodes (and uses their GPUs) but fails on the non-GPU nodes with the following error:

This is because the CUDA folder does not exist on the non-GPU nodes. What's the recommended way to address this? Can GPU-JAX be configured to ignore the above error and proceed with only the CPUs when CUDA is not available?