diff --git a/cpm_kernels/library/cudart.py b/cpm_kernels/library/cudart.py index 477e6bf..2162ab4 100644 --- a/cpm_kernels/library/cudart.py +++ b/cpm_kernels/library/cudart.py @@ -298,6 +298,11 @@ cudaMemcpyDefault = 4 # Direction of the transfer is inferred from the pointer values. Requires unified virtual addressing +# For version compatible +MALLOC_AYNC_SUPPORT = False +GET_FUNC_BY_SYMBOL_SUPPORT = False + + class dim3(ctypes.Structure): _fields_ = [ ('x', ctypes.c_uint), @@ -340,6 +345,8 @@ def cudaDriverGetVersion() -> int: try: version = cudaRuntimeGetVersion() + if version >= 11200: MALLOC_AYNC_SUPPORT = True + if version >= 11000: GET_FUNC_BY_SYMBOL_SUPPORT = True except RuntimeError: version = 0 @@ -379,6 +386,17 @@ def cudaMalloc(size : int) -> ctypes.c_void_p: def cudaFree(ptr : ctypes.c_void_p) -> None: checkCUDAStatus(cuda.cudaFree(ptr)) +if MALLOC_AYNC_SUPPORT: + @cuda.bind("cudaMallocAsync", [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t, cudaStream_t], cudaError_t) + def cudaMallocAsync(size : int, stream : cudaStream_t) -> ctypes.c_void_p: + ptr = ctypes.c_void_p() + checkCUDAStatus(cuda.cudaMallocAsync(ctypes.byref(ptr), size, stream)) + return ptr + + @cuda.bind("cudaFreeAsync", [ctypes.c_void_p, cudaStream_t], cudaError_t) + def cudaFreeAsync(ptr : ctypes.c_void_p, stream : cudaStream_t) -> None: + checkCUDAStatus(cuda.cudaFreeAsync(ptr, stream)) + @cuda.bind("cudaMallocHost", [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t], cudaError_t) def cudaMallocHost(size : int) -> ctypes.c_void_p: ptr = ctypes.c_void_p() @@ -474,7 +492,7 @@ def cudaLaunchKernel( kernelParams = None checkCUDAStatus(cuda.cudaLaunchKernel(func, gridDim, blockDim, kernelParams, sharedMem, stream)) -if version >= 11000: +if GET_FUNC_BY_SYMBOL_SUPPORT: @cuda.bind("cudaGetFuncBySymbol", [ctypes.POINTER(ctypes.c_void_p), ctypes.c_void_p], cudaError_t) def cudaGetFuncBySymbol(func : ctypes.c_void_p) -> ctypes.c_void_p: ret = ctypes.c_void_p()