Skip to content

Commit d4a9701

Browse files
committed
IPC API implementation for CUDA provider
1 parent 5bf1b5e commit d4a9701

File tree

1 file changed: +97 −2 lines

src/provider/provider_cuda.c

Lines changed: 97 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ typedef struct cu_ops_t {
5151

5252
CUresult (*cuGetErrorName)(CUresult error, const char **pStr);
5353
CUresult (*cuGetErrorString)(CUresult error, const char **pStr);
54+
CUresult (*cuIpcGetMemHandle)(CUipcMemHandle *pHandle, CUdeviceptr dptr);
55+
CUresult (*cuIpcOpenMemHandle)(CUdeviceptr *pdptr, CUipcMemHandle handle,
56+
unsigned int Flags);
57+
CUresult (*cuIpcCloseMemHandle)(CUdeviceptr dptr);
5458
} cu_ops_t;
5559

5660
static cu_ops_t g_cu_ops;
@@ -117,11 +121,19 @@ static void init_cu_global_state(void) {
117121
utils_get_symbol_addr(0, "cuGetErrorName", lib_name);
118122
*(void **)&g_cu_ops.cuGetErrorString =
119123
utils_get_symbol_addr(0, "cuGetErrorString", lib_name);
124+
*(void **)&g_cu_ops.cuIpcGetMemHandle =
125+
utils_get_symbol_addr(0, "cuIpcGetMemHandle", lib_name);
126+
*(void **)&g_cu_ops.cuIpcOpenMemHandle =
127+
utils_get_symbol_addr(0, "cuIpcOpenMemHandle_v2", lib_name);
128+
*(void **)&g_cu_ops.cuIpcCloseMemHandle =
129+
utils_get_symbol_addr(0, "cuIpcCloseMemHandle", lib_name);
120130

121131
if (!g_cu_ops.cuMemGetAllocationGranularity || !g_cu_ops.cuMemAlloc ||
122132
!g_cu_ops.cuMemAllocHost || !g_cu_ops.cuMemAllocManaged ||
123133
!g_cu_ops.cuMemFree || !g_cu_ops.cuMemFreeHost ||
124-
!g_cu_ops.cuGetErrorName || !g_cu_ops.cuGetErrorString) {
134+
!g_cu_ops.cuGetErrorName || !g_cu_ops.cuGetErrorString ||
135+
!g_cu_ops.cuIpcGetMemHandle || !g_cu_ops.cuIpcOpenMemHandle ||
136+
!g_cu_ops.cuIpcCloseMemHandle) {
125137
LOG_ERR("Required CUDA symbols not found.");
126138
Init_cu_global_state_failed = true;
127139
}
@@ -352,6 +364,89 @@ static const char *cu_memory_provider_get_name(void *provider) {
352364
return "CUDA";
353365
}
354366

367+
typedef CUipcMemHandle cu_ipc_data_t;
368+
369+
static umf_result_t cu_memory_provider_get_ipc_handle_size(void *provider,
370+
size_t *size) {
371+
if (provider == NULL || size == NULL) {
372+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
373+
}
374+
375+
*size = sizeof(cu_ipc_data_t);
376+
return UMF_RESULT_SUCCESS;
377+
}
378+
379+
static umf_result_t cu_memory_provider_get_ipc_handle(void *provider,
380+
const void *ptr,
381+
size_t size,
382+
void *providerIpcData) {
383+
(void)size;
384+
385+
if (provider == NULL || ptr == NULL || providerIpcData == NULL) {
386+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
387+
}
388+
389+
CUresult cu_result;
390+
cu_ipc_data_t *cu_ipc_data = (cu_ipc_data_t *)providerIpcData;
391+
392+
cu_result = g_cu_ops.cuIpcGetMemHandle(cu_ipc_data, (CUdeviceptr)ptr);
393+
if (cu_result != CUDA_SUCCESS) {
394+
LOG_ERR("cuIpcGetMemHandle() failed.");
395+
return cu2umf_result(cu_result);
396+
}
397+
398+
return UMF_RESULT_SUCCESS;
399+
}
400+
401+
static umf_result_t cu_memory_provider_put_ipc_handle(void *provider,
402+
void *providerIpcData) {
403+
if (provider == NULL || providerIpcData == NULL) {
404+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
405+
}
406+
407+
return UMF_RESULT_SUCCESS;
408+
}
409+
410+
static umf_result_t cu_memory_provider_open_ipc_handle(void *provider,
411+
void *providerIpcData,
412+
void **ptr) {
413+
if (provider == NULL || ptr == NULL || providerIpcData == NULL) {
414+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
415+
}
416+
417+
CUresult cu_result;
418+
cu_ipc_data_t *cu_ipc_data = (cu_ipc_data_t *)providerIpcData;
419+
420+
cu_result = g_cu_ops.cuIpcOpenMemHandle((CUdeviceptr *)ptr, *cu_ipc_data,
421+
CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS);
422+
423+
if (cu_result != CUDA_SUCCESS) {
424+
LOG_ERR("cuIpcOpenMemHandle() failed.");
425+
return cu2umf_result(cu_result);
426+
}
427+
428+
return UMF_RESULT_SUCCESS;
429+
}
430+
431+
static umf_result_t
432+
cu_memory_provider_close_ipc_handle(void *provider, void *ptr, size_t size) {
433+
(void)size;
434+
435+
if (provider == NULL || ptr == NULL) {
436+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
437+
}
438+
439+
CUresult cu_result;
440+
441+
cu_result = g_cu_ops.cuIpcCloseMemHandle((CUdeviceptr)ptr);
442+
if (cu_result != CUDA_SUCCESS) {
443+
LOG_ERR("cuIpcCloseMemHandle() failed.");
444+
return cu2umf_result(cu_result);
445+
}
446+
447+
return UMF_RESULT_SUCCESS;
448+
}
449+
355450
static struct umf_memory_provider_ops_t UMF_CUDA_MEMORY_PROVIDER_OPS = {
356451
.version = UMF_VERSION_CURRENT,
357452
.initialize = cu_memory_provider_initialize,
@@ -368,12 +463,12 @@ static struct umf_memory_provider_ops_t UMF_CUDA_MEMORY_PROVIDER_OPS = {
368463
.ext.purge_force = cu_memory_provider_purge_force,
369464
.ext.allocation_merge = cu_memory_provider_allocation_merge,
370465
.ext.allocation_split = cu_memory_provider_allocation_split,
466+
*/
371467
.ipc.get_ipc_handle_size = cu_memory_provider_get_ipc_handle_size,
372468
.ipc.get_ipc_handle = cu_memory_provider_get_ipc_handle,
373469
.ipc.put_ipc_handle = cu_memory_provider_put_ipc_handle,
374470
.ipc.open_ipc_handle = cu_memory_provider_open_ipc_handle,
375471
.ipc.close_ipc_handle = cu_memory_provider_close_ipc_handle,
376-
*/
377472
};
378473

379474
umf_memory_provider_ops_t *umfCUDAMemoryProviderOps(void) {

0 commit comments

Comments
 (0)