You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
* Added optional installation target for GPU jaxlib.
* Code review fixes to new installation options.
Bumped jax version to 0.2.13 because jax isntall targets are not available before that.
* Adding separate installation option for pinned jaxlib.
* Simplifying TPU install instructions.
* PR review fixes.
Copy file name to clipboardExpand all lines: README.md
+19-3Lines changed: 19 additions & 3 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -184,15 +184,31 @@ Pyro users will note that the API for model specification and inference is large
184
184
185
185
> **Limited Windows Support:** Note that NumPyro is untested on Windows, and might require building jaxlib from source. See this [JAX issue](https://github.com/google/jax/issues/438) for more details. Alternatively, you can install [Windows Subsystem for Linux](https://docs.microsoft.com/en-us/windows/wsl/) and use NumPyro on it as on a Linux system. See also [CUDA on Windows Subsystem for Linux](https://developer.nvidia.com/cuda/wsl) and [this forum post](https://forum.pyro.ai/t/numpyro-with-gpu-works-on-windows/2690) if you want to use GPUs on Windows.
186
186
187
-
To install NumPyro with a CPU version of JAX, you can use pip:
187
+
To install NumPyro with the latest CPU version of JAX, you can use pip:
188
188
189
189
```
190
190
pip install numpyro
191
191
```
192
192
193
-
To use NumPyro on the GPU, you will need to first [install](https://github.com/google/jax#installation)`jax` and `jaxlib` with CUDA support.
193
+
In case of compatibility issues arise during execution of the above command, you can instead force the installation of a known
194
+
compatible CPU version of JAX with
194
195
195
-
To run NumPyro on Cloud TPUs, you can use pip to install NumPyro as above and setup the TPU backend as detailed [here](https://github.com/google/jax/tree/master/cloud_tpu_colabs).
196
+
```
197
+
pip install numpyro[cpu]
198
+
```
199
+
200
+
To use **NumPyro on the GPU**, you need to install CUDA first and then use the following pip command:
201
+
```
202
+
# change `cuda111` to your CUDA version number, e.g. for CUDA 10.2 use `cuda102`
If you need further guidance, please have a look at the [JAX GPU installation instructions](https://github.com/google/jax#pip-installation-gpu-cuda).
206
+
207
+
To run **NumPyro on Cloud TPUs**, you can look at some [JAX on Cloud TPU examples](https://github.com/google/jax/tree/master/cloud_tpu_colabs).
208
+
209
+
For Cloud TPU VM, you need to setup the TPU backend as detailed in the [Cloud TPU VM JAX Quickstart Guide](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm).
210
+
After you have verified that the TPU backend is properly set up,
211
+
you can install NumPyro using the `pip install numpyro` command.
196
212
197
213
> **Default Platform:** JAX will use GPU by default if CUDA-supported `jaxlib` package is installed. You can use [set_platform](http://num.pyro.ai/en/stable/utilities.html#set-platform) utility `numpyro.set_platform("cpu")` to switch to CPU at the beginning of your program.
0 commit comments