Commit f129cff
ci: install GPU JAX in GPU CI (#4293)
Summary by CodeRabbit (auto-generated release notes):

- **Chores**
  - Updated the workflow configuration for CUDA testing to improve efficiency and concurrency.
  - Added a new package for enhanced environment setup in CUDA testing.
  - Introduced an environment variable to optimize GPU memory allocation during tests.

Signed-off-by: Jinzhe Zeng <[email protected]>
1 parent fb41a4f commit f129cff

File tree: 1 file changed (+3, −1 lines changed)
.github/workflows/test_cuda.yml

Lines changed: 3 additions & 1 deletion
```diff
@@ -47,7 +47,7 @@ jobs:
            && sudo apt-get -y install cuda-12-3 libcudnn8=8.9.5.*-1+cuda12.3
          if: false # skip as we use nvidia image
        - run: python -m pip install -U uv
-       - run: source/install/uv_with_retry.sh pip install --system "tensorflow~=2.18.0rc2" "torch~=2.5.0"
+       - run: source/install/uv_with_retry.sh pip install --system "tensorflow~=2.18.0rc2" "torch~=2.5.0" "jax[cuda12]"
        - run: |
            export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])')
            export TENSORFLOW_ROOT=$(python -c 'import importlib,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
@@ -61,6 +61,8 @@ jobs:
          env:
            NUM_WORKERS: 0
            CUDA_VISIBLE_DEVICES: 0
+           # See https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
+           XLA_PYTHON_CLIENT_PREALLOCATE: false
        - name: Download libtorch
          run: |
            wget https://download.pytorch.org/libtorch/cu124/libtorch-cxx11-abi-shared-with-deps-2.5.0%2Bcu124.zip -O libtorch.zip
```
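The workflow's `TENSORFLOW_ROOT` step locates an installed package's directory via `importlib.util.find_spec`. A minimal sketch of that technique, demonstrated on the stdlib `json` package since TensorFlow may not be installed locally (the helper name `package_root` is mine, not from the workflow):

```python
import importlib.util
import pathlib

def package_root(name: str) -> pathlib.Path:
    """Return the directory containing an importable package.

    Mirrors the workflow's one-liner: find_spec(...).origin points at the
    package's __init__.py, so its parent is the package directory.
    """
    spec = importlib.util.find_spec(name)
    if spec is None or spec.origin is None:
        raise ModuleNotFoundError(name)
    return pathlib.Path(spec.origin).parent

print(package_root("json"))
```

The same pattern works for `torch` (the `PYTORCH_ROOT` step uses `torch.__path__[0]`, which resolves to the same directory).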
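The `XLA_PYTHON_CLIENT_PREALLOCATE: false` setting stops JAX from preallocating most of the GPU's memory at startup, which matters here because TensorFlow and PyTorch tests share the same GPU. Per the linked JAX docs, XLA reads this variable when the backend initializes, so in a Python process it must be set before the first `import jax`. A minimal sketch (the `import jax` line is commented out so the snippet runs without JAX installed):

```python
import os

# Must be set before jax is first imported in this process; afterwards
# the value is ignored because the XLA backend is already initialized.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# import jax  # would now allocate GPU memory on demand rather than preallocating

print(os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"])
```

Setting it in the workflow's `env:` block achieves the same effect, since GitHub Actions exports it before the test process starts.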
