@@ -51,6 +51,10 @@ typedef struct cu_ops_t {
5151
5252 CUresult (* cuGetErrorName )(CUresult error , const char * * pStr );
5353 CUresult (* cuGetErrorString )(CUresult error , const char * * pStr );
54+ CUresult (* cuIpcGetMemHandle )(CUipcMemHandle * pHandle , CUdeviceptr dptr );
55+ CUresult (* cuIpcOpenMemHandle )(CUdeviceptr * pdptr , CUipcMemHandle handle ,
56+ unsigned int Flags );
57+ CUresult (* cuIpcCloseMemHandle )(CUdeviceptr dptr );
5458} cu_ops_t ;
5559
5660static cu_ops_t g_cu_ops ;
@@ -117,11 +121,19 @@ static void init_cu_global_state(void) {
117121 utils_get_symbol_addr (0 , "cuGetErrorName" , lib_name );
118122 * (void * * )& g_cu_ops .cuGetErrorString =
119123 utils_get_symbol_addr (0 , "cuGetErrorString" , lib_name );
124+ * (void * * )& g_cu_ops .cuIpcGetMemHandle =
125+ utils_get_symbol_addr (0 , "cuIpcGetMemHandle" , lib_name );
126+ * (void * * )& g_cu_ops .cuIpcOpenMemHandle =
127+ utils_get_symbol_addr (0 , "cuIpcOpenMemHandle_v2" , lib_name );
128+ * (void * * )& g_cu_ops .cuIpcCloseMemHandle =
129+ utils_get_symbol_addr (0 , "cuIpcCloseMemHandle" , lib_name );
120130
121131 if (!g_cu_ops .cuMemGetAllocationGranularity || !g_cu_ops .cuMemAlloc ||
122132 !g_cu_ops .cuMemAllocHost || !g_cu_ops .cuMemAllocManaged ||
123133 !g_cu_ops .cuMemFree || !g_cu_ops .cuMemFreeHost ||
124- !g_cu_ops .cuGetErrorName || !g_cu_ops .cuGetErrorString ) {
134+ !g_cu_ops .cuGetErrorName || !g_cu_ops .cuGetErrorString ||
135+ !g_cu_ops .cuIpcGetMemHandle || !g_cu_ops .cuIpcOpenMemHandle ||
136+ !g_cu_ops .cuIpcCloseMemHandle ) {
125137 LOG_ERR ("Required CUDA symbols not found." );
126138 Init_cu_global_state_failed = true;
127139 }
@@ -352,6 +364,99 @@ static const char *cu_memory_provider_get_name(void *provider) {
352364 return "CUDA" ;
353365}
354366
367+ typedef CUipcMemHandle cu_ipc_data_t ;
368+
369+ static umf_result_t cu_memory_provider_get_ipc_handle_size (void * provider ,
370+ size_t * size ) {
371+ if (provider == NULL || size == NULL ) {
372+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
373+ }
374+
375+ * size = sizeof (cu_ipc_data_t );
376+ return UMF_RESULT_SUCCESS ;
377+ }
378+
379+ static umf_result_t cu_memory_provider_get_ipc_handle (void * provider ,
380+ const void * ptr ,
381+ size_t size ,
382+ void * providerIpcData ) {
383+ (void )size ;
384+
385+ if (provider == NULL || ptr == NULL || providerIpcData == NULL ) {
386+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
387+ }
388+
389+ CUresult cu_result ;
390+ cu_ipc_data_t * cu_ipc_data = (cu_ipc_data_t * )providerIpcData ;
391+
392+ cu_result = g_cu_ops .cuIpcGetMemHandle (cu_ipc_data , (CUdeviceptr )ptr );
393+ if (cu_result != CUDA_SUCCESS ) {
394+ LOG_ERR ("cuIpcGetMemHandle() failed." );
395+ return cu2umf_result (cu_result );
396+ }
397+
398+ return UMF_RESULT_SUCCESS ;
399+ }
400+
401+ static umf_result_t cu_memory_provider_put_ipc_handle (void * provider ,
402+ void * providerIpcData ) {
403+ if (provider == NULL || providerIpcData == NULL ) {
404+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
405+ }
406+
407+ return UMF_RESULT_SUCCESS ;
408+ }
409+
410+ static umf_result_t cu_memory_provider_open_ipc_handle (void * provider ,
411+ void * providerIpcData ,
412+ void * * ptr ) {
413+ if (provider == NULL || ptr == NULL || providerIpcData == NULL ) {
414+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
415+ }
416+
417+ cu_memory_provider_t * cu_provider = (cu_memory_provider_t * )provider ;
418+
419+ CUresult cu_result ;
420+ cu_ipc_data_t * cu_ipc_data = (cu_ipc_data_t * )providerIpcData ;
421+
422+ // Remember current context and set the one from the provider
423+ CUcontext restore_ctx = NULL ;
424+ umf_result_t umf_result = set_context (cu_provider -> context , & restore_ctx );
425+ if (umf_result != UMF_RESULT_SUCCESS ) {
426+ return umf_result ;
427+ }
428+
429+ cu_result = g_cu_ops .cuIpcOpenMemHandle ((CUdeviceptr * )ptr , * cu_ipc_data ,
430+ CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS );
431+
432+ if (cu_result != CUDA_SUCCESS ) {
433+ LOG_ERR ("cuIpcOpenMemHandle() failed." );
434+ }
435+
436+ set_context (restore_ctx , & restore_ctx );
437+
438+ return cu2umf_result (cu_result );
439+ }
440+
441+ static umf_result_t
442+ cu_memory_provider_close_ipc_handle (void * provider , void * ptr , size_t size ) {
443+ (void )size ;
444+
445+ if (provider == NULL || ptr == NULL ) {
446+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
447+ }
448+
449+ CUresult cu_result ;
450+
451+ cu_result = g_cu_ops .cuIpcCloseMemHandle ((CUdeviceptr )ptr );
452+ if (cu_result != CUDA_SUCCESS ) {
453+ LOG_ERR ("cuIpcCloseMemHandle() failed." );
454+ return cu2umf_result (cu_result );
455+ }
456+
457+ return UMF_RESULT_SUCCESS ;
458+ }
459+
355460static struct umf_memory_provider_ops_t UMF_CUDA_MEMORY_PROVIDER_OPS = {
356461 .version = UMF_VERSION_CURRENT ,
357462 .initialize = cu_memory_provider_initialize ,
@@ -368,12 +473,12 @@ static struct umf_memory_provider_ops_t UMF_CUDA_MEMORY_PROVIDER_OPS = {
368473 .ext.purge_force = cu_memory_provider_purge_force,
369474 .ext.allocation_merge = cu_memory_provider_allocation_merge,
370475 .ext.allocation_split = cu_memory_provider_allocation_split,
476+ */
371477 .ipc .get_ipc_handle_size = cu_memory_provider_get_ipc_handle_size ,
372478 .ipc .get_ipc_handle = cu_memory_provider_get_ipc_handle ,
373479 .ipc .put_ipc_handle = cu_memory_provider_put_ipc_handle ,
374480 .ipc .open_ipc_handle = cu_memory_provider_open_ipc_handle ,
375481 .ipc .close_ipc_handle = cu_memory_provider_close_ipc_handle ,
376- */
377482};
378483
379484umf_memory_provider_ops_t * umfCUDAMemoryProviderOps (void ) {
0 commit comments