Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flashinfer/comm/mnnvl.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ def __init__(
checkCudaErrors(cuda.cuCtxSetCurrent(primary_ctx))

# Set CUDA device
import cuda.cudart as cudart
import cuda.bindings.runtime as cudart

checkCudaErrors(cudart.cudaSetDevice(device_idx))

Expand Down
7 changes: 2 additions & 5 deletions flashinfer/cuda_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,15 @@
limitations under the License.
"""

import cuda.bindings.driver as driver
import cuda.bindings.runtime as runtime
import cuda.cudart as cudart
import cuda.nvrtc as nvrtc
from cuda.bindings import driver, nvrtc, runtime


def _cudaGetErrorEnum(error):
if isinstance(error, driver.CUresult):
err, name = driver.cuGetErrorName(error)
return name if err == driver.CUresult.CUDA_SUCCESS else "<unknown>"
elif isinstance(error, runtime.cudaError_t):
return cudart.cudaGetErrorName(error)[1]
return runtime.cudaGetErrorName(error)[1]
elif isinstance(error, nvrtc.nvrtcResult):
return nvrtc.nvrtcGetErrorString(error)[1]
else:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def generate_build_meta(aot_build_meta: dict) -> None:
"torch",
"ninja",
"requests",
"cuda-python<=12.9",
"cuda-python",
"pynvml",
"einops",
"nvidia-cudnn-frontend>=1.13.0",
Expand Down