Skip to content

Commit f368d98

Browse files
authored
Easier installation for GPU/TPU (#1079)
* Added optional installation target for GPU jaxlib. * Code review fixes to new installation options. Bumped jax version to 0.2.13 because jax isntall targets are not available before that. * Adding separate installation option for pinned jaxlib. * Simplifying TPU install instructions. * PR review fixes.
1 parent f8f482a commit f368d98

File tree

2 files changed

+37
-6
lines changed

2 files changed

+37
-6
lines changed

README.md

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -184,15 +184,31 @@ Pyro users will note that the API for model specification and inference is large
184184

185185
> **Limited Windows Support:** Note that NumPyro is untested on Windows, and might require building jaxlib from source. See this [JAX issue](https://github.com/google/jax/issues/438) for more details. Alternatively, you can install [Windows Subsystem for Linux](https://docs.microsoft.com/en-us/windows/wsl/) and use NumPyro on it as on a Linux system. See also [CUDA on Windows Subsystem for Linux](https://developer.nvidia.com/cuda/wsl) and [this forum post](https://forum.pyro.ai/t/numpyro-with-gpu-works-on-windows/2690) if you want to use GPUs on Windows.
186186
187-
To install NumPyro with a CPU version of JAX, you can use pip:
187+
To install NumPyro with the latest CPU version of JAX, you can use pip:
188188

189189
```
190190
pip install numpyro
191191
```
192192

193-
To use NumPyro on the GPU, you will need to first [install](https://github.com/google/jax#installation) `jax` and `jaxlib` with CUDA support.
193+
In case of compatibility issues arise during execution of the above command, you can instead force the installation of a known
194+
compatible CPU version of JAX with
194195

195-
To run NumPyro on Cloud TPUs, you can use pip to install NumPyro as above and setup the TPU backend as detailed [here](https://github.com/google/jax/tree/master/cloud_tpu_colabs).
196+
```
197+
pip install numpyro[cpu]
198+
```
199+
200+
To use **NumPyro on the GPU**, you need to install CUDA first and then use the following pip command:
201+
```
202+
# change `cuda111` to your CUDA version number, e.g. for CUDA 10.2 use `cuda102`
203+
pip install numpyro[cuda111] -f https://storage.googleapis.com/jax-releases/jax_releases.html
204+
```
205+
If you need further guidance, please have a look at the [JAX GPU installation instructions](https://github.com/google/jax#pip-installation-gpu-cuda).
206+
207+
To run **NumPyro on Cloud TPUs**, you can look at some [JAX on Cloud TPU examples](https://github.com/google/jax/tree/master/cloud_tpu_colabs).
208+
209+
For Cloud TPU VM, you need to setup the TPU backend as detailed in the [Cloud TPU VM JAX Quickstart Guide](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm).
210+
After you have verified that the TPU backend is properly set up,
211+
you can install NumPyro using the `pip install numpyro` command.
196212

197213
> **Default Platform:** JAX will use GPU by default if CUDA-supported `jaxlib` package is installed. You can use [set_platform](http://num.pyro.ai/en/stable/utilities.html#set-platform) utility `numpyro.set_platform("cpu")` to switch to CPU at the beginning of your program.
198214

setup.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,14 @@
99
from setuptools import find_packages, setup
1010

1111
PROJECT_PATH = os.path.dirname(os.path.abspath(__file__))
12+
_available_cuda_versions = [
13+
"101",
14+
"102",
15+
"110",
16+
"111",
17+
] # TODO: align these with what's available in JAX before release
18+
_jax_version_constraints = ">=0.2.13"
19+
_jaxlib_version_constraints = ">=0.1.65"
1220

1321
# Find version
1422
for line in open(os.path.join(PROJECT_PATH, "numpyro", "version.py")):
@@ -23,7 +31,6 @@
2331
sys.stderr.flush()
2432
long_description = ""
2533

26-
2734
setup(
2835
name="numpyro",
2936
version=version,
@@ -32,8 +39,8 @@
3239
url="https://github.com/pyro-ppl/numpyro",
3340
author="Uber AI Labs",
3441
install_requires=[
35-
"jax>=0.2.11",
36-
"jaxlib>=0.1.62",
42+
f"jax{_jax_version_constraints}",
43+
f"jaxlib{_jaxlib_version_constraints}",
3744
"tqdm",
3845
],
3946
extras_require={
@@ -66,6 +73,14 @@
6673
"tfp-nightly<=0.14.0.dev20210608",
6774
],
6875
"examples": ["arviz", "jupyter", "matplotlib", "pandas", "seaborn"],
76+
"cpu": f"jax[cpu]{_jax_version_constraints}",
77+
# TPU and CUDA installations, currently require to add package repository URL, i.e.,
78+
# pip install numpyro[cuda101] -f https://storage.googleapis.com/jax-releases/jax_releases.html
79+
"tpu": f"jax[tpu]{_jax_version_constraints}",
80+
**{
81+
f"cuda{version}": f"jax[cuda{version}]{_jax_version_constraints}"
82+
for version in _available_cuda_versions
83+
},
6984
},
7085
long_description=long_description,
7186
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)