diff --git a/README.md b/README.md index ccdcf4bed..014060211 100644 --- a/README.md +++ b/README.md @@ -232,21 +232,29 @@ See the [docs](https://num.pyro.ai/en/latest/contrib.html#stein-variational-infe To install NumPyro with the latest CPU version of JAX, you can use pip: -``` +```bash pip install numpyro ``` In case of compatibility issues arise during execution of the above command, you can instead force the installation of a known compatible CPU version of JAX with -``` +```bash pip install 'numpyro[cpu]' ``` -To use **NumPyro on the GPU**, you need to install CUDA first and then use the following pip command: +To use **NumPyro on the GPU**, you need to install CUDA first, and based on your CUDA version, you can use the following pip command: +For **CUDA 12.x.y**: + +```bash +pip install 'numpyro[cuda12]' -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html ``` -pip install 'numpyro[cuda]' -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + +For **CUDA 13.x.y**: + +```bash +pip install 'numpyro[cuda13]' -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html ``` If you need further guidance, please have a look at the [JAX GPU installation instructions](https://github.com/jax-ml/jax#pip-installation-gpu-cuda). @@ -261,7 +269,7 @@ you can install NumPyro using the `pip install numpyro` command. You can also install NumPyro from source: -``` +```bash git clone https://github.com/pyro-ppl/numpyro.git cd numpyro # install jax/jaxlib first for CUDA support @@ -270,7 +278,7 @@ pip install -e '.[dev]' # contains additional dependencies for NumPyro developm You can also install NumPyro with conda: -``` +```bash conda install -c conda-forge numpyro ``` diff --git a/setup.py b/setup.py index 926a28999..02d951b1a 100644 --- a/setup.py +++ b/setup.py @@ -87,7 +87,8 @@ # TPU and CUDA installations, currently require to add package repository URL, i.e., # pip install 'numpyro[cuda]' -f https://storage.googleapis.com/jax-releases/jax_releases.html "tpu": f"jax[tpu]{_jax_version_constraints}", - "cuda": f"jax[cuda]{_jax_version_constraints}", + "cuda12": f"jax[cuda12]{_jax_version_constraints}", + "cuda13": f"jax[cuda13]{_jax_version_constraints}", }, python_requires=">=3.9", long_description=long_description,