Failing to import jax with a jaxlib compatible with CUDA 10.2 #9075
-
I installed jax with a jaxlib compatible with CUDA 10.2 by running:

However, jax raises an error saying the jaxlib version is incompatible whenever I try to import it:

>>> import jax
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/ayrton/anaconda3/envs/proj/lib/python3.8/site-packages/jax/__init__.py", line 37, in <module>
from jax import config as _config_module
File "/home/ayrton/anaconda3/envs/proj/lib/python3.8/site-packages/jax/config.py", line 18, in <module>
from jax._src.config import config
File "/home/ayrton/anaconda3/envs/proj/lib/python3.8/site-packages/jax/_src/config.py", line 27, in <module>
from jax._src import lib
File "/home/ayrton/anaconda3/envs/proj/lib/python3.8/site-packages/jax/_src/lib/__init__.py", line 69, in <module>
_check_jaxlib_version()
File "/home/ayrton/anaconda3/envs/proj/lib/python3.8/site-packages/jax/_src/lib/__init__.py", line 67, in _check_jaxlib_version
raise ValueError(msg)
ValueError: jaxlib is version 0.1.70, but this version of jax requires version 0.1.74.

Installing the latest jaxlib (0.1.75) fixes the error, but then I can't use the GPU, since that release only supports CUDA 11 and newer. I'd appreciate any help. Thank you.
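The error comes from a version comparison that jax runs when it is imported. Below is a minimal sketch of that kind of check, assuming plain dotted version strings (optionally with a local suffix such as `+cuda102`). The helper names are hypothetical and this is not jax's actual `_check_jaxlib_version` code. Reading versions through `importlib.metadata` (Python 3.8+) works even when `import jax` itself fails, which makes it handy for diagnosing this situation.

```python
from importlib import metadata
from typing import Optional, Tuple


def parse_version(v: str) -> Tuple[int, ...]:
    # Drop a local suffix such as "+cuda102", then "0.1.70" -> (0, 1, 70).
    return tuple(int(part) for part in v.split("+")[0].split("."))


def jaxlib_mismatch(installed: str, required: str) -> Optional[str]:
    """Return an error message if `installed` is older than `required`, else None."""
    if parse_version(installed) < parse_version(required):
        return (f"jaxlib is version {installed}, but this version of jax "
                f"requires version {required}.")
    return None


def installed_version(dist: str) -> Optional[str]:
    """Read a distribution's installed version from pip metadata, without importing it."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    # Print what pip actually installed, even if `import jax` would crash.
    print("jax:", installed_version("jax"), "jaxlib:", installed_version("jaxlib"))
```

With the versions from the traceback, `jaxlib_mismatch("0.1.70", "0.1.74")` returns the same message as the `ValueError`.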
-
I'm also constrained to CUDA 10 and had to install `jax==0.2.25` manually. It is the last jax release that supports jaxlib 0.1.70 (compare the minimum jaxlib version declared in https://github.com/google/jax/blob/jax-v0.2.25/jax/version.py with the one in https://github.com/google/jax/blob/jax-v0.2.26/jax/version.py). pip can't work this out on its own: jax checks the jaxlib version only at import time (the `_check_jaxlib_version` call in the traceback), so pip happily installs a jax that is too new for the installed jaxlib. You'll need to pin the jax version explicitly.
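Picking the version to pin boils down to "the newest jax whose minimum jaxlib is not newer than the jaxlib you have". A minimal sketch of that selection is below. The `MIN_JAXLIB` table is hypothetical: it only encodes what this thread reports (jax 0.2.25 still accepts jaxlib 0.1.70, while the newer jax in the traceback wanted 0.1.74), so for real values read each release's `jax/version.py`.

```python
from typing import Dict, Optional, Tuple


def parse_version(v: str) -> Tuple[int, ...]:
    # Drop a local suffix such as "+cuda102", then "0.1.70" -> (0, 1, 70).
    return tuple(int(p) for p in v.split("+")[0].split("."))


# HYPOTHETICAL table: jax release -> minimum jaxlib it accepts.
# Reflects only what this thread reports; check each release's jax/version.py.
MIN_JAXLIB: Dict[str, str] = {
    "0.2.25": "0.1.70",  # assumed: the thread says 0.2.25 still accepts jaxlib 0.1.70
    "0.2.26": "0.1.74",  # assumed: matches the requirement shown in the traceback
}


def newest_compatible_jax(jaxlib: str, table: Dict[str, str]) -> Optional[str]:
    """Return the newest jax release whose minimum jaxlib <= the installed jaxlib."""
    ok = [jax for jax, min_lib in table.items()
          if parse_version(min_lib) <= parse_version(jaxlib)]
    return max(ok, key=parse_version, default=None)


if __name__ == "__main__":
    jax_version = newest_compatible_jax("0.1.70+cuda102", MIN_JAXLIB)
    print(f"pip install jax=={jax_version}")  # pip install jax==0.2.25
```

Versions are compared as integer tuples rather than strings so that, for example, 0.1.9 sorts before 0.1.10.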