@@ -51,6 +51,10 @@ typedef struct cu_ops_t {
5151
5252 CUresult (* cuGetErrorName )(CUresult error , const char * * pStr );
5353 CUresult (* cuGetErrorString )(CUresult error , const char * * pStr );
54+ CUresult (* cuIpcGetMemHandle )(CUipcMemHandle * pHandle , CUdeviceptr dptr );
55+ CUresult (* cuIpcOpenMemHandle )(CUdeviceptr * pdptr , CUipcMemHandle handle ,
56+ unsigned int Flags );
57+ CUresult (* cuIpcCloseMemHandle )(CUdeviceptr dptr );
5458} cu_ops_t ;
5559
5660static cu_ops_t g_cu_ops ;
@@ -117,11 +121,19 @@ static void init_cu_global_state(void) {
117121 utils_get_symbol_addr (0 , "cuGetErrorName" , lib_name );
118122 * (void * * )& g_cu_ops .cuGetErrorString =
119123 utils_get_symbol_addr (0 , "cuGetErrorString" , lib_name );
124+ * (void * * )& g_cu_ops .cuIpcGetMemHandle =
125+ utils_get_symbol_addr (0 , "cuIpcGetMemHandle" , lib_name );
126+ * (void * * )& g_cu_ops .cuIpcOpenMemHandle =
127+ utils_get_symbol_addr (0 , "cuIpcOpenMemHandle_v2" , lib_name );
128+ * (void * * )& g_cu_ops .cuIpcCloseMemHandle =
129+ utils_get_symbol_addr (0 , "cuIpcCloseMemHandle" , lib_name );
120130
121131 if (!g_cu_ops .cuMemGetAllocationGranularity || !g_cu_ops .cuMemAlloc ||
122132 !g_cu_ops .cuMemAllocHost || !g_cu_ops .cuMemAllocManaged ||
123133 !g_cu_ops .cuMemFree || !g_cu_ops .cuMemFreeHost ||
124- !g_cu_ops .cuGetErrorName || !g_cu_ops .cuGetErrorString ) {
134+ !g_cu_ops .cuGetErrorName || !g_cu_ops .cuGetErrorString ||
135+ !g_cu_ops .cuIpcGetMemHandle || !g_cu_ops .cuIpcOpenMemHandle ||
136+ !g_cu_ops .cuIpcCloseMemHandle ) {
125137 LOG_ERR ("Required CUDA symbols not found." );
126138 Init_cu_global_state_failed = true;
127139 }
@@ -352,6 +364,89 @@ static const char *cu_memory_provider_get_name(void *provider) {
352364 return "CUDA" ;
353365}
354366
// Layout of the provider IPC data buffer used by the CUDA provider's IPC
// functions: they cast the opaque providerIpcData pointer to this type, and
// sizeof(cu_ipc_data_t) is the size reported by get_ipc_handle_size.
367+ typedef CUipcMemHandle cu_ipc_data_t ;
368+
369+ static umf_result_t cu_memory_provider_get_ipc_handle_size (void * provider ,
370+ size_t * size ) {
371+ if (provider == NULL || size == NULL ) {
372+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
373+ }
374+
375+ * size = sizeof (cu_ipc_data_t );
376+ return UMF_RESULT_SUCCESS ;
377+ }
378+
379+ static umf_result_t cu_memory_provider_get_ipc_handle (void * provider ,
380+ const void * ptr ,
381+ size_t size ,
382+ void * providerIpcData ) {
383+ (void )size ;
384+
385+ if (provider == NULL || ptr == NULL || providerIpcData == NULL ) {
386+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
387+ }
388+
389+ CUresult cu_result ;
390+ cu_ipc_data_t * cu_ipc_data = (cu_ipc_data_t * )providerIpcData ;
391+
392+ cu_result = g_cu_ops .cuIpcGetMemHandle (cu_ipc_data , (CUdeviceptr )ptr );
393+ if (cu_result != CUDA_SUCCESS ) {
394+ LOG_ERR ("cuIpcGetMemHandle() failed." );
395+ return cu2umf_result (cu_result );
396+ }
397+
398+ return UMF_RESULT_SUCCESS ;
399+ }
400+
401+ static umf_result_t cu_memory_provider_put_ipc_handle (void * provider ,
402+ void * providerIpcData ) {
403+ if (provider == NULL || providerIpcData == NULL ) {
404+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
405+ }
406+
407+ return UMF_RESULT_SUCCESS ;
408+ }
409+
410+ static umf_result_t cu_memory_provider_open_ipc_handle (void * provider ,
411+ void * providerIpcData ,
412+ void * * ptr ) {
413+ if (provider == NULL || ptr == NULL || providerIpcData == NULL ) {
414+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
415+ }
416+
417+ CUresult cu_result ;
418+ cu_ipc_data_t * cu_ipc_data = (cu_ipc_data_t * )providerIpcData ;
419+
420+ cu_result = g_cu_ops .cuIpcOpenMemHandle ((CUdeviceptr * )ptr , * cu_ipc_data ,
421+ CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS );
422+
423+ if (cu_result != CUDA_SUCCESS ) {
424+ LOG_ERR ("cuIpcOpenMemHandle() failed." );
425+ return cu2umf_result (cu_result );
426+ }
427+
428+ return UMF_RESULT_SUCCESS ;
429+ }
430+
431+ static umf_result_t
432+ cu_memory_provider_close_ipc_handle (void * provider , void * ptr , size_t size ) {
433+ (void )size ;
434+
435+ if (provider == NULL || ptr == NULL ) {
436+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
437+ }
438+
439+ CUresult cu_result ;
440+
441+ cu_result = g_cu_ops .cuIpcCloseMemHandle ((CUdeviceptr )ptr );
442+ if (cu_result != CUDA_SUCCESS ) {
443+ LOG_ERR ("cuIpcCloseMemHandle() failed." );
444+ return cu2umf_result (cu_result );
445+ }
446+
447+ return UMF_RESULT_SUCCESS ;
448+ }
449+
355450static struct umf_memory_provider_ops_t UMF_CUDA_MEMORY_PROVIDER_OPS = {
356451 .version = UMF_VERSION_CURRENT ,
357452 .initialize = cu_memory_provider_initialize ,
@@ -368,12 +463,12 @@ static struct umf_memory_provider_ops_t UMF_CUDA_MEMORY_PROVIDER_OPS = {
368463 .ext.purge_force = cu_memory_provider_purge_force,
369464 .ext.allocation_merge = cu_memory_provider_allocation_merge,
370465 .ext.allocation_split = cu_memory_provider_allocation_split,
466+ */
371467 .ipc .get_ipc_handle_size = cu_memory_provider_get_ipc_handle_size ,
372468 .ipc .get_ipc_handle = cu_memory_provider_get_ipc_handle ,
373469 .ipc .put_ipc_handle = cu_memory_provider_put_ipc_handle ,
374470 .ipc .open_ipc_handle = cu_memory_provider_open_ipc_handle ,
375471 .ipc .close_ipc_handle = cu_memory_provider_close_ipc_handle ,
376- */
377472};
378473
379474umf_memory_provider_ops_t * umfCUDAMemoryProviderOps (void ) {
0 commit comments