Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -232,21 +232,29 @@ See the [docs](https://num.pyro.ai/en/latest/contrib.html#stein-variational-infe

To install NumPyro with the latest CPU version of JAX, you can use pip:

```
```bash
pip install numpyro
```

In case of compatibility issues arise during execution of the above command, you can instead force the installation of a known
compatible CPU version of JAX with

```
```bash
pip install 'numpyro[cpu]'
```

To use **NumPyro on the GPU**, you need to install CUDA first and then use the following pip command:
To use **NumPyro on the GPU**, you need to install CUDA first, and based on your CUDA version, you can use the following pip command:

For **CUDA 12.x.y**:

```bash
pip install 'numpyro[cuda12]' -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
pip install 'numpyro[cuda]' -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

For **CUDA 13.x.y**:

```bash
pip install 'numpyro[cuda13]' -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

If you need further guidance, please have a look at the [JAX GPU installation instructions](https://github.com/jax-ml/jax#pip-installation-gpu-cuda).
Expand All @@ -261,7 +269,7 @@ you can install NumPyro using the `pip install numpyro` command.

You can also install NumPyro from source:

```
```bash
git clone https://github.com/pyro-ppl/numpyro.git
cd numpyro
# install jax/jaxlib first for CUDA support
Expand All @@ -270,7 +278,7 @@ pip install -e '.[dev]' # contains additional dependencies for NumPyro developm

You can also install NumPyro with conda:

```
```bash
conda install -c conda-forge numpyro
```

Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@
# TPU and CUDA installations, currently require to add package repository URL, i.e.,
# pip install 'numpyro[cuda]' -f https://storage.googleapis.com/jax-releases/jax_releases.html
"tpu": f"jax[tpu]{_jax_version_constraints}",
"cuda": f"jax[cuda]{_jax_version_constraints}",
"cuda12": f"jax[cuda12]{_jax_version_constraints}",
"cuda13": f"jax[cuda13]{_jax_version_constraints}",
},
python_requires=">=3.9",
long_description=long_description,
Expand Down