Skip to content

Commit 7d7aa87

Browse files
authored
fix: CI containers install nvidia-cudnn-cu12 vs. nvidia-cudnn-cu13 based on CUDA Version (#1742)
<!-- .github/pull_request_template.md --> ## 📌 Description Current `flashinfer-ci-cu[12x, 130]` containers all install [nvidia-cudnn-cu12](https://pypi.org/project/nvidia-cudnn-cu12/). However, in CUDA 13 environments, [nvidia-cudnn-cu13](https://pypi.org/project/nvidia-cudnn-cu13/) should be installed. The PR modifies `install_python_packages.sh` such that if CUDA 13 is used, `nvidia-cudnn-cu13>=9.12.0.46` is installed (note that [9.12.0.46](https://pypi.org/project/nvidia-cudnn-cu13/#history) is the earliest version) <!-- What does this PR do? Briefly describe the changes and why they’re needed. --> ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->
1 parent 6832b46 commit 7d7aa87

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

docker/install/install_python_packages.sh

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,15 @@ CUDA_VERSION=${1:-cu128}
2525
pip3 install torch --index-url https://download.pytorch.org/whl/${CUDA_VERSION}
2626
pip3 install requests ninja pytest numpy scipy build nvidia-ml-py cuda-python einops nvidia-nvshmem-cu12
2727
pip3 install nvidia-cutlass-dsl
28-
pip3 install 'nvidia-cudnn-frontend>=1.13.0' 'nvidia-cudnn-cu12>=9.11.0.98'
28+
pip3 install 'nvidia-cudnn-frontend>=1.13.0'
29+
30+
# Install cudnn package based on CUDA version
31+
if [[ "$CUDA_VERSION" == *"cu13"* ]]; then
32+
CUDNN_PACKAGE="nvidia-cudnn-cu13>=9.12.0.46"
33+
else
34+
CUDNN_PACKAGE="nvidia-cudnn-cu12>=9.11.0.98"
35+
fi
36+
37+
if [[ -n "$CUDNN_PACKAGE" ]]; then
38+
pip3 install $CUDNN_PACKAGE
39+
fi

0 commit comments

Comments
 (0)