Commit b741065

bugfix: Fix cuda-python v13.0 import compatibility (#1455)
<!-- .github/pull_request_template.md -->

## 📌 Description

<!-- What does this PR do? Briefly describe the changes and why they’re needed. -->

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->

cc: @yzh119 @paul841029
1 parent 60ae234 commit b741065

3 files changed: +35 -5 lines changed

flashinfer/comm/mnnvl.py

Lines changed: 9 additions & 1 deletion
@@ -517,7 +517,15 @@ def __init__(
         checkCudaErrors(cuda.cuCtxSetCurrent(primary_ctx))

         # Set CUDA device
-        import cuda.cudart as cudart
+        # Check if cuda.cudart is available and import accordingly
+        from flashinfer.utils import has_cuda_cudart
+
+        if has_cuda_cudart():
+            # cuda-python <= 12.9
+            import cuda.cudart as cudart
+        else:
+            # cuda-python >= 13.0
+            import cuda.bindings.runtime as cudart

         checkCudaErrors(cudart.cudaSetDevice(device_idx))

flashinfer/cuda_utils.py

Lines changed: 14 additions & 4 deletions
@@ -14,10 +14,20 @@
 limitations under the License.
 """

-import cuda.bindings.driver as driver
-import cuda.bindings.runtime as runtime
-import cuda.cudart as cudart
-import cuda.nvrtc as nvrtc
+from flashinfer.utils import has_cuda_cudart
+
+# Check if cuda.cudart module is available and import accordingly
+if has_cuda_cudart():
+    # cuda-python <= 12.9 (has cuda.cudart)
+    import cuda.bindings.driver as driver
+    import cuda.bindings.runtime as runtime
+    import cuda.cudart as cudart
+    import cuda.nvrtc as nvrtc
+else:
+    # cuda-python >= 13.0 (no cuda.cudart, use runtime as cudart)
+    from cuda.bindings import driver, nvrtc, runtime
+
+    cudart = runtime  # Alias runtime as cudart for compatibility


 def _cudaGetErrorEnum(error):
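
The same version-dependent choice could also be written as a try-in-order import. The block below is a minimal sketch of that alternative, not what this commit does. `import_first` is a hypothetical helper name and is not part of FlashInfer.

```python
import importlib
from types import ModuleType
from typing import Sequence


def import_first(candidates: Sequence[str]) -> ModuleType:
    """Import and return the first module in `candidates` that can be imported.

    Hypothetical helper for illustration only.
    """
    last_error = None
    for name in candidates:
        try:
            return importlib.import_module(name)
        except ImportError as exc:  # ModuleNotFoundError is a subclass of ImportError
            last_error = exc
    raise ImportError(f"none of {list(candidates)} could be imported") from last_error
```

Under this sketch, `cudart = import_first(["cuda.cudart", "cuda.bindings.runtime"])` would bind whichever module the installed cuda-python provides. Unlike the `find_spec` check used in the commit, this approach actually executes the candidate module's import, so any import-time side effects of the first candidate apply.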

flashinfer/utils.py

Lines changed: 12 additions & 0 deletions
@@ -417,6 +417,18 @@ def version_at_least(version: str, base_version: str) -> bool:
     return pkg_version.parse(version) >= pkg_version.parse(base_version)


+def has_cuda_cudart() -> bool:
+    """
+    Check if cuda.cudart module is available (cuda-python <= 12.9).
+
+    Returns:
+        True if cuda.cudart exists, False otherwise
+    """
+    import importlib.util
+
+    return importlib.util.find_spec("cuda.cudart") is not None
+
+
 def is_sm90a_supported(device: torch.device) -> bool:
     major, _ = get_compute_capability(device)
     return major == 9 and version_at_least(torch.version.cuda, "12.3")
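
One detail worth noting about `has_cuda_cudart`: `importlib.util.find_spec` imports the parent package when given a dotted name, so `find_spec("cuda.cudart")` raises `ModuleNotFoundError` instead of returning `None` if the top-level `cuda` package is not installed at all. A defensive variant might look like the sketch below. `module_available` is a hypothetical name and not part of this commit.

```python
import importlib.util


def module_available(name: str) -> bool:
    """Return True if `name` can be located, False if it or a parent package is missing.

    Hypothetical helper for illustration only.
    """
    try:
        return importlib.util.find_spec(name) is not None
    except (ImportError, ValueError):
        # ImportError: a parent package could not be imported.
        # ValueError: the module is already in sys.modules but has no usable __spec__.
        return False
```

With this sketch, `module_available("cuda.cudart")` would return `False` both on cuda-python >= 13.0 and on a machine with no cuda-python installed, leaving the later `import` statement to report the missing dependency.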
