Skip to content

Commit 4f884bf

Browse files
hotfix: make aot wheel work without nvcc (#1782)
<!-- .github/pull_request_template.md --> ## 📌 Description Change the logic of `get_cuda_version`: when nvcc is not installed, return torch.version.cuda, this patch could unblock #1781 . ## 🔍 Related Issues #1781 ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes cc @Flynn-Zh --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent db7af92 commit 4f884bf

File tree

1 file changed

+17
-7
lines changed

1 file changed

+17
-7
lines changed

flashinfer/jit/cpp_ext.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,23 @@ def get_cuda_version() -> Version:
5555
nvcc = "nvcc"
5656
else:
5757
nvcc = os.path.join(CUDA_HOME, "bin/nvcc")
58-
txt = subprocess.check_output([nvcc, "--version"], text=True)
59-
matches = re.findall(r"release (\d+\.\d+),", txt)
60-
if not matches:
61-
raise RuntimeError(
62-
f"Could not parse CUDA version from nvcc --version output: {txt}"
63-
)
64-
return Version(matches[0])
58+
# Try to query nvcc for CUDA version; if nvcc is unavailable, fall back to torch.version.cuda
59+
try:
60+
txt = subprocess.check_output([nvcc, "--version"], text=True)
61+
matches = re.findall(r"release (\d+\.\d+),", txt)
62+
if not matches:
63+
raise RuntimeError(
64+
f"Could not parse CUDA version from nvcc --version output: {txt}"
65+
)
66+
return Version(matches[0])
67+
except (FileNotFoundError, subprocess.CalledProcessError) as e:
68+
# NOTE(Zihao): when nvcc is unavailable, fall back to torch.version.cuda
69+
if torch.version.cuda is None:
70+
raise RuntimeError(
71+
"nvcc not found and PyTorch is not built with CUDA support. "
72+
"Could not determine CUDA version."
73+
) from e
74+
return Version(torch.version.cuda)
6575

6676

6777
def is_cuda_version_at_least(version_str: str) -> bool:

0 commit comments

Comments
 (0)