Skip to content

Commit 0e65df5

Browse files
committed
Use cuda12-local for JAX installation across all workflows
Changes all workflows to use jax[cuda12-local] to leverage the CUDA and cuDNN libraries pre-installed in the new AMI (ami-0edec81935264b6d3). This is faster than cuda12_pip and uses the system libraries. Also removes version pin (==0.6.2) from cache.yml and publish.yml to make all workflows consistent.
1 parent a688344 commit 0e65df5

File tree

3 files changed

+3
-3
lines changed

3 files changed

+3
-3
lines changed

.github/workflows/cache.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ jobs:
2323
- name: Install JAX and Numpyro
2424
shell: bash -l {0}
2525
run: |
26-
pip install --upgrade "jax[cuda12_pip]==0.6.2"
26+
pip install --upgrade "jax[cuda12-local]"
2727
pip install numpyro
2828
python scripts/test-jax-install.py
2929
- name: Check nvidia drivers

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ jobs:
3131
- name: Install JAX and Numpyro
3232
shell: bash -l {0}
3333
run: |
34-
pip install --upgrade "jax[cuda12_pip]"
34+
pip install --upgrade "jax[cuda12-local]"
3535
pip install numpyro
3636
python scripts/test-jax-install.py
3737
- name: Check nvidia Drivers

.github/workflows/publish.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ jobs:
2222
- name: Install JAX and Numpyro
2323
shell: bash -l {0}
2424
run: |
25-
pip install --upgrade "jax[cuda12_pip]==0.6.2"
25+
pip install --upgrade "jax[cuda12-local]"
2626
pip install numpyro
2727
python scripts/test-jax-install.py
2828
- name: Check nvidia drivers

0 commit comments

Comments
 (0)