10 changes: 9 additions & 1 deletion flashinfer/comm/mnnvl.py
@@ -517,7 +517,15 @@ def __init__(
         checkCudaErrors(cuda.cuCtxSetCurrent(primary_ctx))
 
         # Set CUDA device
-        import cuda.cudart as cudart
+        # Check if cuda.cudart is available and import accordingly
+        from flashinfer.utils import has_cuda_cudart
+
+        if has_cuda_cudart():
+            # cuda-python <= 12.9
+            import cuda.cudart as cudart
+        else:
+            # cuda-python >= 13.0
+            import cuda.bindings.runtime as cudart
 
         checkCudaErrors(cudart.cudaSetDevice(device_idx))
 
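For context, a minimal standalone sketch (not part of the diff) of the same fallback pattern: it assumes cuda-python is installed, uses a local error-check helper instead of flashinfer's checkCudaErrors, and the device index is hypothetical.

# Standalone sketch of the cudart import fallback shown above.
# Assumes cuda-python is installed; works with both the 12.x and 13.x layouts.
import importlib.util

if importlib.util.find_spec("cuda.cudart") is not None:
    import cuda.cudart as cudart  # cuda-python <= 12.9
else:
    import cuda.bindings.runtime as cudart  # cuda-python >= 13.0


def _check(result):
    # cuda-python runtime calls return a tuple whose first element is the error code.
    err, *values = result
    if err != cudart.cudaError_t.cudaSuccess:
        raise RuntimeError(f"CUDA runtime error: {err}")
    return values[0] if len(values) == 1 else values


device_idx = 0  # hypothetical device index
_check(cudart.cudaSetDevice(device_idx))
print("active CUDA device:", _check(cudart.cudaGetDevice()))

Either way, the rest of the method keeps calling cudart.cudaSetDevice unchanged, since both import paths bind the runtime module to the same name.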
18 changes: 14 additions & 4 deletions flashinfer/cuda_utils.py
@@ -14,10 +14,20 @@
 limitations under the License.
 """
 
-import cuda.bindings.driver as driver
-import cuda.bindings.runtime as runtime
-import cuda.cudart as cudart
-import cuda.nvrtc as nvrtc
+from flashinfer.utils import has_cuda_cudart
+
+# Check if cuda.cudart module is available and import accordingly
+if has_cuda_cudart():
+    # cuda-python <= 12.9 (has cuda.cudart)
+    import cuda.bindings.driver as driver
+    import cuda.bindings.runtime as runtime
+    import cuda.cudart as cudart
+    import cuda.nvrtc as nvrtc
+else:
+    # cuda-python >= 13.0 (no cuda.cudart, use runtime as cudart)
+    from cuda.bindings import driver, nvrtc, runtime
+
+    cudart = runtime  # Alias runtime as cudart for compatibility
 
 
 def _cudaGetErrorEnum(error):
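A sketch (not part of the diff) of why the alias is safe: in the cuda-python 13.x layout, cudart simply points at cuda.bindings.runtime, so existing cudart.* call sites in this module keep resolving. It assumes cuda-python is installed.

# Illustration of the cudart/runtime aliasing used above.
import importlib.util

if importlib.util.find_spec("cuda.cudart") is not None:
    import cuda.cudart as cudart  # cuda-python <= 12.9
    import cuda.bindings.runtime as runtime
else:
    from cuda.bindings import runtime

    cudart = runtime  # same aliasing as in cuda_utils.py

err, version = cudart.cudaRuntimeGetVersion()
assert err == cudart.cudaError_t.cudaSuccess
print("CUDA runtime version reported via cudart:", version)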
12 changes: 12 additions & 0 deletions flashinfer/utils.py
@@ -415,6 +415,18 @@ def version_at_least(version: str, base_version: str) -> bool:
     return pkg_version.parse(version) >= pkg_version.parse(base_version)
 
 
+def has_cuda_cudart() -> bool:
+    """
+    Check if cuda.cudart module is available (cuda-python <= 12.9).
+
+    Returns:
+        True if cuda.cudart exists, False otherwise
+    """
+    import importlib.util
+
+    return importlib.util.find_spec("cuda.cudart") is not None
+
+
 def is_sm90a_supported(device: torch.device) -> bool:
     major, _ = get_compute_capability(device)
     return major == 9 and version_at_least(torch.version.cuda, "12.3")
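A short usage sketch of the new helper (not part of the diff). It assumes flashinfer is importable; the helper only probes module availability via importlib.util.find_spec and does not import cuda.cudart itself.

# Detect which cuda-python layout is installed.
from flashinfer.utils import has_cuda_cudart

if has_cuda_cudart():
    print("cuda-python <= 12.9 layout detected: cuda.cudart is importable")
else:
    print("cuda-python >= 13.0 layout detected: use cuda.bindings.runtime")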