@@ -55,6 +55,14 @@ umf_result_t umfCUDAMemoryProviderParamsSetMemoryType(
5555 return UMF_RESULT_ERROR_NOT_SUPPORTED ;
5656}
5757
58+ umf_result_t umfCUDAMemoryProviderParamsSetAllocFlags (
59+ umf_cuda_memory_provider_params_handle_t hParams , unsigned int flags ) {
60+ (void )hParams ;
61+ (void )flags ;
62+ LOG_ERR ("CUDA provider is disabled (UMF_BUILD_CUDA_PROVIDER is OFF)!" );
63+ return UMF_RESULT_ERROR_NOT_SUPPORTED ;
64+ }
65+
5866umf_memory_provider_ops_t * umfCUDAMemoryProviderOps (void ) {
5967 // not supported
6068 LOG_ERR ("CUDA provider is disabled (UMF_BUILD_CUDA_PROVIDER is OFF)!" );
@@ -89,13 +97,22 @@ typedef struct cu_memory_provider_t {
8997 CUdevice device ;
9098 umf_usm_memory_type_t memory_type ;
9199 size_t min_alignment ;
100+ unsigned int alloc_flags ;
92101} cu_memory_provider_t ;
93102
94103// CUDA Memory Provider settings struct
95104typedef struct umf_cuda_memory_provider_params_t {
96- void * cuda_context_handle ; ///< Handle to the CUDA context
97- int cuda_device_handle ; ///< Handle to the CUDA device
98- umf_usm_memory_type_t memory_type ; ///< Allocation memory type
105+ // Handle to the CUDA context
106+ void * cuda_context_handle ;
107+
108+ // Handle to the CUDA device
109+ int cuda_device_handle ;
110+
111+ // Allocation memory type
112+ umf_usm_memory_type_t memory_type ;
113+
114+ // Allocation flags for cuMemHostAlloc/cuMemAllocManaged
115+ unsigned int alloc_flags ;
99116} umf_cuda_memory_provider_params_t ;
100117
101118typedef struct cu_ops_t {
@@ -104,6 +121,7 @@ typedef struct cu_ops_t {
104121 CUmemAllocationGranularity_flags option );
105122 CUresult (* cuMemAlloc )(CUdeviceptr * dptr , size_t bytesize );
106123 CUresult (* cuMemAllocHost )(void * * pp , size_t bytesize );
124+ CUresult (* cuMemHostAlloc )(void * * pp , size_t bytesize , unsigned int flags );
107125 CUresult (* cuMemAllocManaged )(CUdeviceptr * dptr , size_t bytesize ,
108126 unsigned int flags );
109127 CUresult (* cuMemFree )(CUdeviceptr dptr );
@@ -175,6 +193,8 @@ static void init_cu_global_state(void) {
175193 utils_get_symbol_addr (0 , "cuMemAlloc_v2" , lib_name );
176194 * (void * * )& g_cu_ops .cuMemAllocHost =
177195 utils_get_symbol_addr (0 , "cuMemAllocHost_v2" , lib_name );
196+ * (void * * )& g_cu_ops .cuMemHostAlloc =
197+ utils_get_symbol_addr (0 , "cuMemHostAlloc" , lib_name );
178198 * (void * * )& g_cu_ops .cuMemAllocManaged =
179199 utils_get_symbol_addr (0 , "cuMemAllocManaged" , lib_name );
180200 * (void * * )& g_cu_ops .cuMemFree =
@@ -197,12 +217,12 @@ static void init_cu_global_state(void) {
197217 utils_get_symbol_addr (0 , "cuIpcCloseMemHandle" , lib_name );
198218
199219 if (!g_cu_ops .cuMemGetAllocationGranularity || !g_cu_ops .cuMemAlloc ||
200- !g_cu_ops .cuMemAllocHost || !g_cu_ops .cuMemAllocManaged ||
201- !g_cu_ops .cuMemFree || !g_cu_ops .cuMemFreeHost ||
202- !g_cu_ops .cuGetErrorName || !g_cu_ops .cuGetErrorString ||
203- !g_cu_ops .cuCtxGetCurrent || !g_cu_ops .cuCtxSetCurrent ||
204- !g_cu_ops .cuIpcGetMemHandle || !g_cu_ops .cuIpcOpenMemHandle ||
205- !g_cu_ops .cuIpcCloseMemHandle ) {
220+ !g_cu_ops .cuMemAllocHost || !g_cu_ops .cuMemHostAlloc ||
221+ !g_cu_ops .cuMemAllocManaged || !g_cu_ops .cuMemFree ||
222+ !g_cu_ops .cuMemFreeHost || !g_cu_ops .cuGetErrorName ||
223+ !g_cu_ops .cuGetErrorString || !g_cu_ops .cuCtxGetCurrent ||
224+ !g_cu_ops .cuCtxSetCurrent || !g_cu_ops .cuIpcGetMemHandle ||
225+ !g_cu_ops .cuIpcOpenMemHandle || ! g_cu_ops . cuIpcCloseMemHandle ) {
206226 LOG_ERR ("Required CUDA symbols not found." );
207227 Init_cu_global_state_failed = true;
208228 }
@@ -226,6 +246,7 @@ umf_result_t umfCUDAMemoryProviderParamsCreate(
226246 params_data -> cuda_context_handle = NULL ;
227247 params_data -> cuda_device_handle = -1 ;
228248 params_data -> memory_type = UMF_MEMORY_TYPE_UNKNOWN ;
249+ params_data -> alloc_flags = 0 ;
229250
230251 * hParams = params_data ;
231252
@@ -276,6 +297,18 @@ umf_result_t umfCUDAMemoryProviderParamsSetMemoryType(
276297 return UMF_RESULT_SUCCESS ;
277298}
278299
300+ umf_result_t umfCUDAMemoryProviderParamsSetAllocFlags (
301+ umf_cuda_memory_provider_params_handle_t hParams , unsigned int flags ) {
302+ if (!hParams ) {
303+ LOG_ERR ("CUDA Memory Provider params handle is NULL" );
304+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
305+ }
306+
307+ hParams -> alloc_flags = flags ;
308+
309+ return UMF_RESULT_SUCCESS ;
310+ }
311+
279312static umf_result_t cu_memory_provider_initialize (void * params ,
280313 void * * provider ) {
281314 if (params == NULL ) {
@@ -295,6 +328,24 @@ static umf_result_t cu_memory_provider_initialize(void *params,
295328 return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
296329 }
297330
331+ if (cu_params -> memory_type == UMF_MEMORY_TYPE_SHARED ) {
332+ if (cu_params -> alloc_flags == 0 ) {
333+ // if flags are not set, the default setting is CU_MEM_ATTACH_GLOBAL
334+ cu_params -> alloc_flags = CU_MEM_ATTACH_GLOBAL ;
335+ } else if (cu_params -> alloc_flags != CU_MEM_ATTACH_GLOBAL &&
336+ cu_params -> alloc_flags != CU_MEM_ATTACH_HOST ) {
337+ LOG_ERR ("Invalid shared allocation flags" );
338+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
339+ }
340+ } else if (cu_params -> memory_type == UMF_MEMORY_TYPE_HOST ) {
341+ if (cu_params -> alloc_flags &
342+ ~(CU_MEMHOSTALLOC_PORTABLE | CU_MEMHOSTALLOC_DEVICEMAP |
343+ CU_MEMHOSTALLOC_WRITECOMBINED )) {
344+ LOG_ERR ("Invalid host allocation flags" );
345+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
346+ }
347+ }
348+
298349 utils_init_once (& cu_is_initialized , init_cu_global_state );
299350 if (Init_cu_global_state_failed ) {
300351 LOG_ERR ("Loading CUDA symbols failed" );
@@ -325,6 +376,7 @@ static umf_result_t cu_memory_provider_initialize(void *params,
325376 cu_provider -> device = cu_params -> cuda_device_handle ;
326377 cu_provider -> memory_type = cu_params -> memory_type ;
327378 cu_provider -> min_alignment = min_alignment ;
379+ cu_provider -> alloc_flags = cu_params -> alloc_flags ;
328380
329381 * provider = cu_provider ;
330382
@@ -382,7 +434,8 @@ static umf_result_t cu_memory_provider_alloc(void *provider, size_t size,
382434 CUresult cu_result = CUDA_SUCCESS ;
383435 switch (cu_provider -> memory_type ) {
384436 case UMF_MEMORY_TYPE_HOST : {
385- cu_result = g_cu_ops .cuMemAllocHost (resultPtr , size );
437+ cu_result =
438+ g_cu_ops .cuMemHostAlloc (resultPtr , size , cu_provider -> alloc_flags );
386439 break ;
387440 }
388441 case UMF_MEMORY_TYPE_DEVICE : {
@@ -391,7 +444,7 @@ static umf_result_t cu_memory_provider_alloc(void *provider, size_t size,
391444 }
392445 case UMF_MEMORY_TYPE_SHARED : {
393446 cu_result = g_cu_ops .cuMemAllocManaged ((CUdeviceptr * )resultPtr , size ,
394- CU_MEM_ATTACH_GLOBAL );
447+ cu_provider -> alloc_flags );
395448 break ;
396449 }
397450 default :
0 commit comments