Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions .github/workflows/test_cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
runs-on: nvidia
# https://github.com/deepmodeling/deepmd-kit/pull/2884#issuecomment-1744216845
container:
image: nvidia/cuda:12.6.2-cudnn-devel-ubuntu22.04
image: nvidia/cuda:12.9.1-cudnn-devel-ubuntu22.04
options: --gpus all
if: github.repository_owner == 'deepmodeling' && (github.event_name == 'pull_request' && github.event.label && github.event.label.name == 'Test CUDA' || github.event_name == 'workflow_dispatch' || github.event_name == 'merge_group')
steps:
Expand All @@ -36,12 +36,6 @@ jobs:
with:
useLocalCache: true
useCloudCache: false
- run: |
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.0-1_all.deb \
&& sudo dpkg -i cuda-keyring_1.0-1_all.deb \
&& sudo apt-get update \
&& sudo apt-get -y install cuda-12-3 libcudnn8=8.9.5.*-1+cuda12.3
if: false # skip as we use nvidia image
- run: python -m pip install -U uv
- run: source/install/uv_with_retry.sh pip install --system --group pin_tensorflow_gpu --group pin_pytorch_gpu --group pin_jax "jax[cuda12]"
- run: |
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ pin_pytorch_gpu = [
"torch>=2.7,<2.10",
]
pin_jax = [
"jax==0.5.0;python_version>='3.10'",
"jax==0.6.2;python_version>='3.10'",
]

[tool.setuptools_scm]
Expand Down