File tree Expand file tree Collapse file tree 2 files changed +3
-9
lines changed Expand file tree Collapse file tree 2 files changed +3
-9
lines changed Original file line number Diff line number Diff line change @@ -238,7 +238,7 @@ pip install numpyro[cpu]
238238
239239To use ** NumPyro on the GPU** , you need to install CUDA first and then use the following pip command:
240240```
241- pip install numpyro[cuda] -f https://storage.googleapis.com/jax-releases/jax_releases .html
241+ pip install numpyro[cuda] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases .html
242242```
243243If you need further guidance, please have a look at the [ JAX GPU installation instructions] ( https://github.com/google/jax#pip-installation-gpu-cuda ) .
244244
Original file line number Diff line number Diff line change @@ -8,9 +8,7 @@ FROM nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04
88
99# declare the image name
1010# note that this image uses Python 3.8
11- ENV IMG_NAME=11.2.2-cudnn8-devel-ubuntu20.04 \
12- # declare the cuda version for pulling appropriate jaxlib wheel
13- JAXLIB_CUDA=111
11+ ENV IMG_NAME=11.2.2-cudnn8-devel-ubuntu20.04
1412
1513# install python3 and pip on top of the base Ubuntu image
1614# unlike for release, we need to install git and setuptools too
@@ -22,11 +20,7 @@ RUN apt update && \
2220ENV PATH=/root/.local/bin:$PATH
2321
2422# install python packages via pip
25- # install pip-versions to detect the latest version of jax and jaxlib
26- RUN pip3 install pip-versions
27- # this uses latest version of jax and jaxlib available from pypi
28- RUN pip-versions latest jaxlib | xargs -I{} pip3 install jaxlib=={}+cuda${JAXLIB_CUDA} -f https://storage.googleapis.com/jax-releases/jax_releases.html \
29- jax
23+ RUN pip3 install jax[cuda] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
3024
3125# clone the numpyro git repository and run pip install
3226RUN git clone https://github.com/pyro-ppl/numpyro.git && \
You can’t perform that action at this time.
0 commit comments