@@ -53,6 +53,10 @@ typedef struct cu_ops_t {
5353 CUresult (* cuGetErrorString )(CUresult error , const char * * pStr );
5454 CUresult (* cuCtxGetCurrent )(CUcontext * pctx );
5555 CUresult (* cuCtxSetCurrent )(CUcontext ctx );
56+ CUresult (* cuIpcGetMemHandle )(CUipcMemHandle * pHandle , CUdeviceptr dptr );
57+ CUresult (* cuIpcOpenMemHandle )(CUdeviceptr * pdptr , CUipcMemHandle handle ,
58+ unsigned int Flags );
59+ CUresult (* cuIpcCloseMemHandle )(CUdeviceptr dptr );
5660} cu_ops_t ;
5761
5862static cu_ops_t g_cu_ops ;
@@ -123,12 +127,20 @@ static void init_cu_global_state(void) {
123127 utils_get_symbol_addr (0 , "cuCtxGetCurrent" , lib_name );
124128 * (void * * )& g_cu_ops .cuCtxSetCurrent =
125129 utils_get_symbol_addr (0 , "cuCtxSetCurrent" , lib_name );
130+ * (void * * )& g_cu_ops .cuIpcGetMemHandle =
131+ utils_get_symbol_addr (0 , "cuIpcGetMemHandle" , lib_name );
132+ * (void * * )& g_cu_ops .cuIpcOpenMemHandle =
133+ utils_get_symbol_addr (0 , "cuIpcOpenMemHandle_v2" , lib_name );
134+ * (void * * )& g_cu_ops .cuIpcCloseMemHandle =
135+ utils_get_symbol_addr (0 , "cuIpcCloseMemHandle" , lib_name );
126136
127137 if (!g_cu_ops .cuMemGetAllocationGranularity || !g_cu_ops .cuMemAlloc ||
128138 !g_cu_ops .cuMemAllocHost || !g_cu_ops .cuMemAllocManaged ||
129139 !g_cu_ops .cuMemFree || !g_cu_ops .cuMemFreeHost ||
130140 !g_cu_ops .cuGetErrorName || !g_cu_ops .cuGetErrorString ||
131- !g_cu_ops .cuCtxGetCurrent || !g_cu_ops .cuCtxSetCurrent ) {
141+ !g_cu_ops .cuCtxGetCurrent || !g_cu_ops .cuCtxSetCurrent ||
142+ !g_cu_ops .cuIpcGetMemHandle || !g_cu_ops .cuIpcOpenMemHandle ||
143+ !g_cu_ops .cuIpcCloseMemHandle ) {
132144 LOG_ERR ("Required CUDA symbols not found." );
133145 Init_cu_global_state_failed = true;
134146 }
@@ -396,6 +408,99 @@ static const char *cu_memory_provider_get_name(void *provider) {
396408 return "CUDA" ;
397409}
398410
411+ typedef CUipcMemHandle cu_ipc_data_t ;
412+
413+ static umf_result_t cu_memory_provider_get_ipc_handle_size (void * provider ,
414+ size_t * size ) {
415+ if (provider == NULL || size == NULL ) {
416+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
417+ }
418+
419+ * size = sizeof (cu_ipc_data_t );
420+ return UMF_RESULT_SUCCESS ;
421+ }
422+
423+ static umf_result_t cu_memory_provider_get_ipc_handle (void * provider ,
424+ const void * ptr ,
425+ size_t size ,
426+ void * providerIpcData ) {
427+ (void )size ;
428+
429+ if (provider == NULL || ptr == NULL || providerIpcData == NULL ) {
430+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
431+ }
432+
433+ CUresult cu_result ;
434+ cu_ipc_data_t * cu_ipc_data = (cu_ipc_data_t * )providerIpcData ;
435+
436+ cu_result = g_cu_ops .cuIpcGetMemHandle (cu_ipc_data , (CUdeviceptr )ptr );
437+ if (cu_result != CUDA_SUCCESS ) {
438+ LOG_ERR ("cuIpcGetMemHandle() failed." );
439+ return cu2umf_result (cu_result );
440+ }
441+
442+ return UMF_RESULT_SUCCESS ;
443+ }
444+
445+ static umf_result_t cu_memory_provider_put_ipc_handle (void * provider ,
446+ void * providerIpcData ) {
447+ if (provider == NULL || providerIpcData == NULL ) {
448+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
449+ }
450+
451+ return UMF_RESULT_SUCCESS ;
452+ }
453+
454+ static umf_result_t cu_memory_provider_open_ipc_handle (void * provider ,
455+ void * providerIpcData ,
456+ void * * ptr ) {
457+ if (provider == NULL || ptr == NULL || providerIpcData == NULL ) {
458+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
459+ }
460+
461+ cu_memory_provider_t * cu_provider = (cu_memory_provider_t * )provider ;
462+
463+ CUresult cu_result ;
464+ cu_ipc_data_t * cu_ipc_data = (cu_ipc_data_t * )providerIpcData ;
465+
466+ // Remember current context and set the one from the provider
467+ CUcontext restore_ctx = NULL ;
468+ umf_result_t umf_result = set_context (cu_provider -> context , & restore_ctx );
469+ if (umf_result != UMF_RESULT_SUCCESS ) {
470+ return umf_result ;
471+ }
472+
473+ cu_result = g_cu_ops .cuIpcOpenMemHandle ((CUdeviceptr * )ptr , * cu_ipc_data ,
474+ CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS );
475+
476+ if (cu_result != CUDA_SUCCESS ) {
477+ LOG_ERR ("cuIpcOpenMemHandle() failed." );
478+ }
479+
480+ set_context (restore_ctx , & restore_ctx );
481+
482+ return cu2umf_result (cu_result );
483+ }
484+
485+ static umf_result_t
486+ cu_memory_provider_close_ipc_handle (void * provider , void * ptr , size_t size ) {
487+ (void )size ;
488+
489+ if (provider == NULL || ptr == NULL ) {
490+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
491+ }
492+
493+ CUresult cu_result ;
494+
495+ cu_result = g_cu_ops .cuIpcCloseMemHandle ((CUdeviceptr )ptr );
496+ if (cu_result != CUDA_SUCCESS ) {
497+ LOG_ERR ("cuIpcCloseMemHandle() failed." );
498+ return cu2umf_result (cu_result );
499+ }
500+
501+ return UMF_RESULT_SUCCESS ;
502+ }
503+
399504static struct umf_memory_provider_ops_t UMF_CUDA_MEMORY_PROVIDER_OPS = {
400505 .version = UMF_VERSION_CURRENT ,
401506 .initialize = cu_memory_provider_initialize ,
@@ -412,12 +517,12 @@ static struct umf_memory_provider_ops_t UMF_CUDA_MEMORY_PROVIDER_OPS = {
412517 .ext.purge_force = cu_memory_provider_purge_force,
413518 .ext.allocation_merge = cu_memory_provider_allocation_merge,
414519 .ext.allocation_split = cu_memory_provider_allocation_split,
520+ */
415521 .ipc .get_ipc_handle_size = cu_memory_provider_get_ipc_handle_size ,
416522 .ipc .get_ipc_handle = cu_memory_provider_get_ipc_handle ,
417523 .ipc .put_ipc_handle = cu_memory_provider_put_ipc_handle ,
418524 .ipc .open_ipc_handle = cu_memory_provider_open_ipc_handle ,
419525 .ipc .close_ipc_handle = cu_memory_provider_close_ipc_handle ,
420- */
421526};
422527
423528umf_memory_provider_ops_t * umfCUDAMemoryProviderOps (void ) {
0 commit comments