Skip to content

Commit ae215a6

Browse files
limin2021claude
andcommitted
fix: auto-detect CUDA version for nvidia-cutlass-dsl extra and clean install
- requirements.txt: remove [cu13] extra (keep generic for Docker builds) - setup_test_env.sh: detect CUDA major version from torch to select [cu12] or [cu13] extra, and clean uninstall old packages before installing (per NVIDIA docs recommendation) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent aa17f54 commit ae215a6

File tree

2 files changed

+12
-3
lines changed

2 files changed

+12
-3
lines changed

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ einops
55
ninja
66
numpy
77
nvidia-cudnn-frontend>=1.13.0
8-
nvidia-cutlass-dsl[cu13]>=4.4.2
8+
nvidia-cutlass-dsl>=4.4.2
99
nvidia-ml-py
1010
packaging>=24.2
1111
requests

scripts/setup_test_env.sh

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,19 @@ fi
2626

2727
# Override nvidia-cutlass-dsl if specified
2828
if [ -n "${CUTLASS_DSL_VERSION:-}" ]; then
29+
# Detect CUDA major version to select the correct extra (cu12 or cu13)
30+
CUDA_MAJOR=$(python -c "import torch; print(torch.version.cuda.split('.')[0])" 2>/dev/null || echo "12")
31+
if [ "$CUDA_MAJOR" = "13" ]; then
32+
CUTLASS_DSL_EXTRA="cu13"
33+
else
34+
CUTLASS_DSL_EXTRA="cu12"
35+
fi
2936
echo "========================================"
30-
echo "Overriding nvidia-cutlass-dsl with version: ${CUTLASS_DSL_VERSION}"
37+
echo "Overriding nvidia-cutlass-dsl with version: ${CUTLASS_DSL_VERSION} [${CUTLASS_DSL_EXTRA}]"
3138
echo "========================================"
32-
pip install --force-reinstall "nvidia-cutlass-dsl[cu13]==${CUTLASS_DSL_VERSION}"
39+
# Clean uninstall old packages first (recommended by NVIDIA docs)
40+
pip uninstall nvidia-cutlass-dsl nvidia-cutlass-dsl-libs-base nvidia-cutlass-dsl-libs-cu12 nvidia-cutlass-dsl-libs-cu13 -y 2>/dev/null || true
41+
pip install "nvidia-cutlass-dsl[${CUTLASS_DSL_EXTRA}]==${CUTLASS_DSL_VERSION}"
3342
echo "nvidia-cutlass-dsl override complete."
3443
echo ""
3544
fi

0 commit comments

Comments
 (0)