Skip to content

Commit d0cbac5

Browse files
mgoin and Harry-Chen
authored
[Dev UX] Add auto-detection for VLLM_PRECOMPILED_WHEEL_VARIANT during install (vllm-project#32948)
Signed-off-by: mgoin <[email protected]> Signed-off-by: Michael Goin <[email protected]> Co-authored-by: Shengqi Chen <[email protected]>
1 parent c0d8204 commit d0cbac5

File tree

2 files changed

+51
-6
lines changed

2 files changed

+51
-6
lines changed

docs/getting_started/installation/gpu.cuda.inc.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ There are more environment variables to control the behavior of Python-only buil
118118

119119
* `VLLM_PRECOMPILED_WHEEL_LOCATION`: specify the exact wheel URL or local file path of a pre-compiled wheel to use. All other logic to find the wheel will be skipped.
120120
* `VLLM_PRECOMPILED_WHEEL_COMMIT`: override the commit hash to download the pre-compiled wheel. It can be `nightly` to use the last **already built** commit on the main branch.
121-
* `VLLM_PRECOMPILED_WHEEL_VARIANT`: specify the variant subdirectory to use on the nightly index, e.g., `cu129`, `cpu`. If not specified, the CUDA variant with `VLLM_MAIN_CUDA_VERSION` will be tried, then fallback to the default variant on the remote index.
121+
* `VLLM_PRECOMPILED_WHEEL_VARIANT`: specify the variant subdirectory to use on the nightly index, e.g., `cu129`, `cu130`, `cpu`. If not specified, the variant is auto-detected based on your system's CUDA version (from PyTorch or nvidia-smi). You can also set `VLLM_MAIN_CUDA_VERSION` to override auto-detection.
122122

123123
You can find more information about vLLM's wheels in [Install the latest code](#install-the-latest-code).
124124

setup.py

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,49 @@ def is_rocm_system() -> bool:
438438
except ImportError:
439439
return False
440440

441+
@staticmethod
442+
def detect_system_cuda_variant() -> str:
443+
"""Auto-detect CUDA variant from torch, nvidia-smi, or env default."""
444+
445+
# Map CUDA major version to hosted wheel variants on wheels.vllm.ai
446+
supported = {12: "cu129", 13: "cu130"}
447+
448+
# Respect explicitly set VLLM_MAIN_CUDA_VERSION
449+
if envs.is_set("VLLM_MAIN_CUDA_VERSION"):
450+
v = envs.VLLM_MAIN_CUDA_VERSION
451+
print(f"Using VLLM_MAIN_CUDA_VERSION={v}")
452+
return "cu" + v.replace(".", "")[:3]
453+
454+
# Try torch.version.cuda
455+
cuda_version = None
456+
try:
457+
import torch
458+
459+
cuda_version = torch.version.cuda
460+
except Exception:
461+
pass
462+
463+
# Try nvidia-smi
464+
if not cuda_version:
465+
try:
466+
out = subprocess.run(
467+
["nvidia-smi"], capture_output=True, text=True, timeout=10
468+
)
469+
if m := re.search(r"CUDA Version:\s*(\d+\.\d+)", out.stdout):
470+
cuda_version = m.group(1)
471+
except Exception:
472+
pass
473+
474+
# Fall back to default
475+
if not cuda_version:
476+
cuda_version = envs.VLLM_MAIN_CUDA_VERSION
477+
478+
# Map to supported variant
479+
major = int(cuda_version.split(".")[0])
480+
variant = supported.get(major, supported[max(supported)])
481+
print(f"Detected CUDA {cuda_version}, using variant {variant}")
482+
return variant
483+
441484
@staticmethod
442485
def find_local_rocm_wheel() -> str | None:
443486
"""Search for a local vllm wheel in common locations."""
@@ -513,8 +556,8 @@ def determine_wheel_url() -> tuple[str, str | None]:
513556
1. user-specified wheel location (can be either local or remote, via
514557
VLLM_PRECOMPILED_WHEEL_LOCATION)
515558
2. user-specified variant (VLLM_PRECOMPILED_WHEEL_VARIANT) from nightly repo
516-
3. the variant corresponding to VLLM_MAIN_CUDA_VERSION from nightly repo
517-
4. the default variant from nightly repo
559+
or auto-detected CUDA variant based on system (torch, nvidia-smi)
560+
3. the default variant from nightly repo
518561
519562
If downloading from the nightly repo, the commit can be specified via
520563
VLLM_PRECOMPILED_WHEEL_COMMIT; otherwise, the head commit in the main branch
@@ -533,9 +576,11 @@ def determine_wheel_url() -> tuple[str, str | None]:
533576
import platform
534577

535578
arch = platform.machine()
536-
# try to fetch the wheel metadata from the nightly wheel repo
537-
main_variant = "cu" + envs.VLLM_MAIN_CUDA_VERSION.replace(".", "")
538-
variant = os.getenv("VLLM_PRECOMPILED_WHEEL_VARIANT", main_variant)
579+
# try to fetch the wheel metadata from the nightly wheel repo,
580+
# detecting CUDA variant from system if not specified
581+
variant = os.getenv("VLLM_PRECOMPILED_WHEEL_VARIANT", None)
582+
if variant is None:
583+
variant = precompiled_wheel_utils.detect_system_cuda_variant()
539584
commit = os.getenv("VLLM_PRECOMPILED_WHEEL_COMMIT", "").lower()
540585
if not commit or len(commit) != 40:
541586
print(

0 commit comments

Comments
 (0)