Skip to content

Commit f652c69

Browse files
committed
Fix JAX CUDA installation to use cuda12_pip instead of cuda12-local
Changes jax[cuda12-local] to jax[cuda12_pip] to avoid cuDNN version compatibility issues. The cuda12_pip variant includes compatible CUDA and cuDNN libraries bundled with JAX, preventing runtime errors from mismatched local CUDA installations. Fixes cuDNN 9.10.0 backward-compatibility error.
1 parent 90f6a6f commit f652c69

File tree

3 files changed

+3
-3
lines changed

3 files changed

+3
-3
lines changed

.github/workflows/cache.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ jobs:
2323
- name: Install JAX and Numpyro
2424
shell: bash -l {0}
2525
run: |
26-
pip install --upgrade "jax[cuda12-local]==0.6.2"
26+
pip install --upgrade "jax[cuda12_pip]==0.6.2"
2727
pip install numpyro
2828
python scripts/test-jax-install.py
2929
- name: Check nvidia drivers

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ jobs:
3131
- name: Install JAX and Numpyro
3232
shell: bash -l {0}
3333
run: |
34-
pip install --upgrade "jax[cuda12-local]"
34+
pip install --upgrade "jax[cuda12_pip]"
3535
pip install numpyro
3636
python scripts/test-jax-install.py
3737
- name: Check nvidia Drivers

.github/workflows/publish.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ jobs:
2222
- name: Install JAX and Numpyro
2323
shell: bash -l {0}
2424
run: |
25-
pip install --upgrade "jax[cuda12-local]==0.6.2"
25+
pip install --upgrade "jax[cuda12_pip]==0.6.2"
2626
pip install numpyro
2727
python scripts/test-jax-install.py
2828
- name: Check nvidia drivers

0 commit comments

Comments
 (0)