Skip to content

Commit 702773e

Browse files
committed
fix: JAX version pinning
1 parent 66013bc commit 702773e

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ jobs:
3333
run: |
3434
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
3535
pip install pyro-ppl
36-
pip install --upgrade "jax[cuda12-local]==0.6.2"
36+
pip install --upgrade "jax[cuda12-local]==0.8.0"
3737
pip install numpyro pyro-ppl
3838
python scripts/test-jax-install.py
3939
- name: Check nvidia Drivers

0 commit comments

Comments
 (0)