Skip to content

Commit 727bcf6

Browse files
committed
Update JAX to use cuda13-local for CUDA 13 support
Changes all workflows from cuda12-local to cuda13-local to match the CUDA 13 installation in the new AMI (ami-0edec81935264b6d3).
1 parent 0e65df5 commit 727bcf6

File tree

3 files changed

+3
-3
lines changed

3 files changed

+3
-3
lines changed

.github/workflows/cache.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ jobs:
2323
- name: Install JAX and Numpyro
2424
shell: bash -l {0}
2525
run: |
26-
pip install --upgrade "jax[cuda12-local]"
26+
pip install --upgrade "jax[cuda13-local]"
2727
pip install numpyro
2828
python scripts/test-jax-install.py
2929
- name: Check nvidia drivers

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ jobs:
3131
- name: Install JAX and Numpyro
3232
shell: bash -l {0}
3333
run: |
34-
pip install --upgrade "jax[cuda12-local]"
34+
pip install --upgrade "jax[cuda13-local]"
3535
pip install numpyro
3636
python scripts/test-jax-install.py
3737
- name: Check nvidia Drivers

.github/workflows/publish.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ jobs:
2222
- name: Install JAX and Numpyro
2323
shell: bash -l {0}
2424
run: |
25-
pip install --upgrade "jax[cuda12-local]"
25+
pip install --upgrade "jax[cuda13-local]"
2626
pip install numpyro
2727
python scripts/test-jax-install.py
2828
- name: Check nvidia drivers

0 commit comments

Comments
 (0)