Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion cpm_kernels/library/cudart.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,11 @@
cudaMemcpyDefault = 4
# Direction of the transfer is inferred from the pointer values. Requires unified virtual addressing

# Feature flags for CUDA runtime version compatibility.
# Both start out False; the module-level runtime-version probe in this file
# switches them on when the detected version is new enough.
# Gates the cudaMallocAsync / cudaFreeAsync bindings (enabled when version >= 11200).
# NOTE(review): the name is spelled "AYNC" (sic); kept as-is because it is
# referenced elsewhere in this file.
MALLOC_AYNC_SUPPORT = False
# Gates the cudaGetFuncBySymbol binding (enabled when version >= 11000).
GET_FUNC_BY_SYMBOL_SUPPORT = False


class dim3(ctypes.Structure):
_fields_ = [
('x', ctypes.c_uint),
Expand Down Expand Up @@ -340,6 +345,8 @@ def cudaDriverGetVersion() -> int:

# Probe the CUDA runtime version once at import time and turn on the
# version-gated feature flags. If the probe raises RuntimeError, fall back to
# version 0 and leave both flags at their current values.
try:
    version = cudaRuntimeGetVersion()
except RuntimeError:
    version = 0
else:
    # Only reached when the probe succeeded.
    if version >= 11000:
        GET_FUNC_BY_SYMBOL_SUPPORT = True
    if version >= 11200:
        MALLOC_AYNC_SUPPORT = True

Expand Down Expand Up @@ -379,6 +386,17 @@ def cudaMalloc(size : int) -> ctypes.c_void_p:
def cudaFree(ptr : ctypes.c_void_p) -> None:
checkCUDAStatus(cuda.cudaFree(ptr))

if MALLOC_AYNC_SUPPORT:
    # Asynchronous allocation bindings; only defined when the runtime-version
    # probe has enabled MALLOC_AYNC_SUPPORT.

    @cuda.bind(
        "cudaMallocAsync",
        [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t, cudaStream_t],
        cudaError_t,
    )
    def cudaMallocAsync(size : int, stream : cudaStream_t) -> ctypes.c_void_p:
        """Call the runtime's cudaMallocAsync and return the resulting pointer.

        Args:
            size: Requested allocation size, passed through as a ``size_t``.
            stream: Stream handle forwarded to the runtime call.

        Returns:
            The ``ctypes.c_void_p`` filled in by the runtime.

        The returned status code is handed to ``checkCUDAStatus``.
        """
        dev_ptr = ctypes.c_void_p()
        status = cuda.cudaMallocAsync(ctypes.byref(dev_ptr), size, stream)
        checkCUDAStatus(status)
        return dev_ptr

    @cuda.bind("cudaFreeAsync", [ctypes.c_void_p, cudaStream_t], cudaError_t)
    def cudaFreeAsync(ptr : ctypes.c_void_p, stream : cudaStream_t) -> None:
        """Call the runtime's cudaFreeAsync on ``ptr`` with the given ``stream``.

        The returned status code is handed to ``checkCUDAStatus``.
        """
        status = cuda.cudaFreeAsync(ptr, stream)
        checkCUDAStatus(status)

@cuda.bind("cudaMallocHost", [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t], cudaError_t)
def cudaMallocHost(size : int) -> ctypes.c_void_p:
ptr = ctypes.c_void_p()
Expand Down Expand Up @@ -474,7 +492,7 @@ def cudaLaunchKernel(
kernelParams = None
checkCUDAStatus(cuda.cudaLaunchKernel(func, gridDim, blockDim, kernelParams, sharedMem, stream))

if version >= 11000:
if GET_FUNC_BY_SYMBOL_SUPPORT:
@cuda.bind("cudaGetFuncBySymbol", [ctypes.POINTER(ctypes.c_void_p), ctypes.c_void_p], cudaError_t)
def cudaGetFuncBySymbol(func : ctypes.c_void_p) -> ctypes.c_void_p:
ret = ctypes.c_void_p()
Expand Down