Jax not utilizing GPU for pymc (StreamExecutorGpuDevice) #12907
Unanswered
imrogerjiang asked this question in Q&A
Replies: 0 comments
Hi community,
I've been trying to use CUDA acceleration with PyMC 4. When I run pymc.sampling_jax.sample_numpyro_nuts, sampling is slow and does not use the GPU: nvidia-smi shows 0% utilization for the process, while the system monitor shows the CPU doing the work.
It also gives the following warning:
/home/roger/.local/lib/python3.10/site-packages/pymc/sampling_jax.py:548: UserWarning: There are not enough devices to run parallel chains: expected 4 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using "numpyro.set_host_device_count(4)" at the beginning of your program. You can double-check how many devices are available in your system using "jax.local_device_count()"
When I run jax.devices(), it returns [StreamExecutorGpuDevice(id=0, process_index=0)], and I suspect the issue could be related to this. What should I fix so that sampling is performed on the GPU? I have followed the install instructions (including CUDA and cuDNN) found at https://github.com/google/jax#installation.
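For anyone trying to reproduce this, a minimal diagnostic sketch along these lines may help separate "JAX cannot see the GPU" from "JAX sees one GPU but the sampler wants four devices". The helper name describe_jax_backend is hypothetical and not part of JAX or PyMC; it only uses jax.default_backend(), jax.devices(), and jax.local_device_count(), and it is not a fix for the problem above.

```python
import importlib.util


def describe_jax_backend():
    """Hypothetical helper: summarize which backend JAX would compute on."""
    if importlib.util.find_spec("jax") is None:
        # JAX is not importable in this environment, so there is nothing to report.
        return {
            "jax_installed": False,
            "backend": None,
            "devices": [],
            "local_device_count": 0,
        }

    import jax
    import jax.numpy as jnp

    # Run a tiny computation so the backend is actually initialized.
    x = jnp.ones((8, 8))
    (x @ x).block_until_ready()

    return {
        "jax_installed": True,
        # Platform name of the default backend, e.g. "cpu" or "gpu".
        "backend": jax.default_backend(),
        # Platform of every device JAX can see.
        "devices": [d.platform for d in jax.devices()],
        # Number of devices attached to this process; the warning above
        # compares this count against the number of chains.
        "local_device_count": jax.local_device_count(),
    }


if __name__ == "__main__":
    for key, value in describe_jax_backend().items():
        print(f"{key}: {value}")
```

If the reported backend is a GPU and the local device count is 1, the "expected 4 but got 1" warning would be consistent with a single GPU being visible, which by itself does not show that the GPU is unused.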
I am using an NVIDIA GTX 1070.
I am running the following packages:
pymc 4.2.2
jax 0.3.23
numpyro 0.10.1
running nvcc --version gives
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Sep_21_10:33:58_PDT_2022
Cuda compilation tools, release 11.8, V11.8.89
Build cuda_11.8.r11.8/compiler.31833905_0
running nvidia-smi gives

Thank you!