Skip to content

Conversation

@Qazalbash
Copy link
Contributor

@Qazalbash Qazalbash commented Nov 5, 2025

Current installation instructions to install numpyro on GPU is not supported for CUDA 13.

uv pip install --upgrade numpyro[cuda] --dry-run -f https://storage.googleapis.com/jax-releases/jax_releases.html
Resolved 23 packages in 1.67s
Would download 23 packages
Would install 23 packages
 + jax==0.8.0
 + jax-cuda12-pjrt==0.8.0
 + jax-cuda12-plugin==0.8.0
 + jaxlib==0.8.0
 + ml-dtypes==0.5.3
 + multipledispatch==1.0.0
 + numpy==2.3.4
 + numpyro==0.19.0
 + nvidia-cublas-cu12==12.9.1.4
 + nvidia-cuda-cupti-cu12==12.9.79
 + nvidia-cuda-nvcc-cu12==12.9.86
 + nvidia-cuda-nvrtc-cu12==12.9.86
 + nvidia-cuda-runtime-cu12==12.9.79
 + nvidia-cudnn-cu12==9.15.0.57
 + nvidia-cufft-cu12==11.4.1.4
 + nvidia-cusolver-cu12==11.7.5.82
 + nvidia-cusparse-cu12==12.5.10.65
 + nvidia-nccl-cu12==2.28.7
 + nvidia-nvjitlink-cu12==12.9.86
 + nvidia-nvshmem-cu12==3.4.5
 + opt-einsum==3.4.0
 + scipy==1.16.3
 + tqdm==4.67.1

This PR updates the setup.py along with relevant documentation to install NumPyro for both CUDA 12 and CUDA 13.

@juanitorduz juanitorduz requested a review from fehiepsi November 5, 2025 09:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant