diff --git a/.github/workflows/cache.yml b/.github/workflows/cache.yml index c499a1599..ec7536a5b 100644 --- a/.github/workflows/cache.yml +++ b/.github/workflows/cache.yml @@ -24,7 +24,7 @@ jobs: shell: bash -l {0} run: | pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128 - pip install --upgrade "jax[cuda12-local]" + pip install --upgrade "jax[cuda12-local]==0.6.2" pip install numpyro python scripts/test-jax-install.py - name: Check nvidia drivers diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index b0474f4a3..b0a9a6c8f 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -23,7 +23,7 @@ jobs: shell: bash -l {0} run: | pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128 - pip install --upgrade "jax[cuda12-local]" + pip install --upgrade "jax[cuda12-local]==0.6.2" pip install numpyro python scripts/test-jax-install.py - name: Check nvidia drivers