@@ -55,6 +55,22 @@ umf_result_t umfCUDAMemoryProviderParamsSetMemoryType(
5555 return UMF_RESULT_ERROR_NOT_SUPPORTED ;
5656}
5757
58+ umf_result_t umfCUDAMemoryProviderParamsSetHostAllocFlags (
59+ umf_cuda_memory_provider_params_handle_t hParams , unsigned int flags ) {
60+ (void )hParams ;
61+ (void )flags ;
62+ LOG_ERR ("CUDA provider is disabled (UMF_BUILD_CUDA_PROVIDER is OFF)!" );
63+ return UMF_RESULT_ERROR_NOT_SUPPORTED ;
64+ }
65+
66+ umf_result_t umfCUDAMemoryProviderParamsSetManagedAllocFlags (
67+ umf_cuda_memory_provider_params_handle_t hParams , unsigned int flags ) {
68+ (void )hParams ;
69+ (void )flags ;
70+ LOG_ERR ("CUDA provider is disabled (UMF_BUILD_CUDA_PROVIDER is OFF)!" );
71+ return UMF_RESULT_ERROR_NOT_SUPPORTED ;
72+ }
73+
5874umf_memory_provider_ops_t * umfCUDAMemoryProviderOps (void ) {
5975 // not supported
6076 LOG_ERR ("CUDA provider is disabled (UMF_BUILD_CUDA_PROVIDER is OFF)!" );
@@ -89,13 +105,17 @@ typedef struct cu_memory_provider_t {
89105 CUdevice device ;
90106 umf_usm_memory_type_t memory_type ;
91107 size_t min_alignment ;
108+ unsigned int host_alloc_flags ;
109+ unsigned int managed_alloc_flags ;
92110} cu_memory_provider_t ;
93111
94112// CUDA Memory Provider settings struct
95113typedef struct umf_cuda_memory_provider_params_t {
96- void * cuda_context_handle ; ///< Handle to the CUDA context
97- int cuda_device_handle ; ///< Handle to the CUDA device
98- umf_usm_memory_type_t memory_type ; ///< Allocation memory type
114+ void * cuda_context_handle ; // Handle to the CUDA context
115+ int cuda_device_handle ; // Handle to the CUDA device
116+ umf_usm_memory_type_t memory_type ; // Allocation memory type
117+ unsigned int host_alloc_flags ; // Allocation flags for cuMemHostAlloc
118+ unsigned int managed_alloc_flags ; // Allocation flags for cuMemAllocManaged
99119} umf_cuda_memory_provider_params_t ;
100120
101121typedef struct cu_ops_t {
@@ -104,6 +124,7 @@ typedef struct cu_ops_t {
104124 CUmemAllocationGranularity_flags option );
105125 CUresult (* cuMemAlloc )(CUdeviceptr * dptr , size_t bytesize );
106126 CUresult (* cuMemAllocHost )(void * * pp , size_t bytesize );
127+ CUresult (* cuMemHostAlloc )(void * * pp , size_t bytesize , unsigned int flags );
107128 CUresult (* cuMemAllocManaged )(CUdeviceptr * dptr , size_t bytesize ,
108129 unsigned int flags );
109130 CUresult (* cuMemFree )(CUdeviceptr dptr );
@@ -175,6 +196,8 @@ static void init_cu_global_state(void) {
175196 utils_get_symbol_addr (0 , "cuMemAlloc_v2" , lib_name );
176197 * (void * * )& g_cu_ops .cuMemAllocHost =
177198 utils_get_symbol_addr (0 , "cuMemAllocHost_v2" , lib_name );
199+ * (void * * )& g_cu_ops .cuMemHostAlloc =
200+ utils_get_symbol_addr (0 , "cuMemHostAlloc" , lib_name );
178201 * (void * * )& g_cu_ops .cuMemAllocManaged =
179202 utils_get_symbol_addr (0 , "cuMemAllocManaged" , lib_name );
180203 * (void * * )& g_cu_ops .cuMemFree =
@@ -197,12 +220,12 @@ static void init_cu_global_state(void) {
197220 utils_get_symbol_addr (0 , "cuIpcCloseMemHandle" , lib_name );
198221
199222 if (!g_cu_ops .cuMemGetAllocationGranularity || !g_cu_ops .cuMemAlloc ||
200- !g_cu_ops .cuMemAllocHost || !g_cu_ops .cuMemAllocManaged ||
201- !g_cu_ops .cuMemFree || !g_cu_ops .cuMemFreeHost ||
202- !g_cu_ops .cuGetErrorName || !g_cu_ops .cuGetErrorString ||
203- !g_cu_ops .cuCtxGetCurrent || !g_cu_ops .cuCtxSetCurrent ||
204- !g_cu_ops .cuIpcGetMemHandle || !g_cu_ops .cuIpcOpenMemHandle ||
205- !g_cu_ops .cuIpcCloseMemHandle ) {
223+ !g_cu_ops .cuMemAllocHost || !g_cu_ops .cuMemHostAlloc ||
224+ !g_cu_ops .cuMemAllocManaged || !g_cu_ops .cuMemFree ||
225+ !g_cu_ops .cuMemFreeHost || !g_cu_ops .cuGetErrorName ||
226+ !g_cu_ops .cuGetErrorString || !g_cu_ops .cuCtxGetCurrent ||
227+ !g_cu_ops .cuCtxSetCurrent || !g_cu_ops .cuIpcGetMemHandle ||
228+ !g_cu_ops .cuIpcOpenMemHandle || ! g_cu_ops . cuIpcCloseMemHandle ) {
206229 LOG_ERR ("Required CUDA symbols not found." );
207230 Init_cu_global_state_failed = true;
208231 }
@@ -226,6 +249,8 @@ umf_result_t umfCUDAMemoryProviderParamsCreate(
226249 params_data -> cuda_context_handle = NULL ;
227250 params_data -> cuda_device_handle = -1 ;
228251 params_data -> memory_type = UMF_MEMORY_TYPE_UNKNOWN ;
252+ params_data -> host_alloc_flags = 0 ;
253+ params_data -> managed_alloc_flags = CU_MEM_ATTACH_GLOBAL ;
229254
230255 * hParams = params_data ;
231256
@@ -276,6 +301,42 @@ umf_result_t umfCUDAMemoryProviderParamsSetMemoryType(
276301 return UMF_RESULT_SUCCESS ;
277302}
278303
304+ umf_result_t umfCUDAMemoryProviderParamsSetHostAllocFlags (
305+ umf_cuda_memory_provider_params_handle_t hParams , unsigned int flags ) {
306+ if (!hParams ) {
307+ LOG_ERR ("CUDA Memory Provider params handle is NULL" );
308+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
309+ }
310+
311+ // mask out valid flags and check if there are bits left
312+ if (flags & ~(CU_MEMHOSTALLOC_PORTABLE | CU_MEMHOSTALLOC_DEVICEMAP |
313+ CU_MEMHOSTALLOC_WRITECOMBINED )) {
314+ LOG_ERR ("Invalid host allocation flags" );
315+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
316+ }
317+
318+ hParams -> host_alloc_flags = flags ;
319+
320+ return UMF_RESULT_SUCCESS ;
321+ }
322+
323+ umf_result_t umfCUDAMemoryProviderParamsSetManagedAllocFlags (
324+ umf_cuda_memory_provider_params_handle_t hParams , unsigned int flags ) {
325+ if (!hParams ) {
326+ LOG_ERR ("CUDA Memory Provider params handle is NULL" );
327+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
328+ }
329+
330+ if (flags != CU_MEM_ATTACH_GLOBAL && flags != CU_MEM_ATTACH_HOST ) {
331+ LOG_ERR ("Invalid managed allocation flags" );
332+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
333+ }
334+
335+ hParams -> managed_alloc_flags = flags ;
336+
337+ return UMF_RESULT_SUCCESS ;
338+ }
339+
279340static umf_result_t cu_memory_provider_initialize (void * params ,
280341 void * * provider ) {
281342 if (params == NULL ) {
@@ -325,6 +386,8 @@ static umf_result_t cu_memory_provider_initialize(void *params,
325386 cu_provider -> device = cu_params -> cuda_device_handle ;
326387 cu_provider -> memory_type = cu_params -> memory_type ;
327388 cu_provider -> min_alignment = min_alignment ;
389+ cu_provider -> host_alloc_flags = cu_params -> host_alloc_flags ;
390+ cu_provider -> managed_alloc_flags = cu_params -> managed_alloc_flags ;
328391
329392 * provider = cu_provider ;
330393
@@ -382,16 +445,17 @@ static umf_result_t cu_memory_provider_alloc(void *provider, size_t size,
382445 CUresult cu_result = CUDA_SUCCESS ;
383446 switch (cu_provider -> memory_type ) {
384447 case UMF_MEMORY_TYPE_HOST : {
385- cu_result = g_cu_ops .cuMemAllocHost (resultPtr , size );
448+ cu_result = g_cu_ops .cuMemHostAlloc (resultPtr , size ,
449+ cu_provider -> host_alloc_flags );
386450 break ;
387451 }
388452 case UMF_MEMORY_TYPE_DEVICE : {
389453 cu_result = g_cu_ops .cuMemAlloc ((CUdeviceptr * )resultPtr , size );
390454 break ;
391455 }
392456 case UMF_MEMORY_TYPE_SHARED : {
393- cu_result = g_cu_ops .cuMemAllocManaged (( CUdeviceptr * ) resultPtr , size ,
394- CU_MEM_ATTACH_GLOBAL );
457+ cu_result = g_cu_ops .cuMemAllocManaged (
458+ ( CUdeviceptr * ) resultPtr , size , cu_provider -> managed_alloc_flags );
395459 break ;
396460 }
397461 default :
0 commit comments