Skip to content

Commit 27d4fba

Browse files
committed
IPC API implementation for CUDA provider
1 parent 5bf1b5e commit 27d4fba

File tree

1 file changed

+107
-2
lines changed

1 file changed

+107
-2
lines changed

src/provider/provider_cuda.c

Lines changed: 107 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ typedef struct cu_ops_t {
5151

5252
CUresult (*cuGetErrorName)(CUresult error, const char **pStr);
5353
CUresult (*cuGetErrorString)(CUresult error, const char **pStr);
54+
CUresult (*cuIpcGetMemHandle)(CUipcMemHandle *pHandle, CUdeviceptr dptr);
55+
CUresult (*cuIpcOpenMemHandle)(CUdeviceptr *pdptr, CUipcMemHandle handle,
56+
unsigned int Flags);
57+
CUresult (*cuIpcCloseMemHandle)(CUdeviceptr dptr);
5458
} cu_ops_t;
5559

5660
static cu_ops_t g_cu_ops;
@@ -117,11 +121,19 @@ static void init_cu_global_state(void) {
117121
utils_get_symbol_addr(0, "cuGetErrorName", lib_name);
118122
*(void **)&g_cu_ops.cuGetErrorString =
119123
utils_get_symbol_addr(0, "cuGetErrorString", lib_name);
124+
*(void **)&g_cu_ops.cuIpcGetMemHandle =
125+
utils_get_symbol_addr(0, "cuIpcGetMemHandle", lib_name);
126+
*(void **)&g_cu_ops.cuIpcOpenMemHandle =
127+
utils_get_symbol_addr(0, "cuIpcOpenMemHandle_v2", lib_name);
128+
*(void **)&g_cu_ops.cuIpcCloseMemHandle =
129+
utils_get_symbol_addr(0, "cuIpcCloseMemHandle", lib_name);
120130

121131
if (!g_cu_ops.cuMemGetAllocationGranularity || !g_cu_ops.cuMemAlloc ||
122132
!g_cu_ops.cuMemAllocHost || !g_cu_ops.cuMemAllocManaged ||
123133
!g_cu_ops.cuMemFree || !g_cu_ops.cuMemFreeHost ||
124-
!g_cu_ops.cuGetErrorName || !g_cu_ops.cuGetErrorString) {
134+
!g_cu_ops.cuGetErrorName || !g_cu_ops.cuGetErrorString ||
135+
!g_cu_ops.cuIpcGetMemHandle || !g_cu_ops.cuIpcOpenMemHandle ||
136+
!g_cu_ops.cuIpcCloseMemHandle) {
125137
LOG_ERR("Required CUDA symbols not found.");
126138
Init_cu_global_state_failed = true;
127139
}
@@ -352,6 +364,99 @@ static const char *cu_memory_provider_get_name(void *provider) {
352364
return "CUDA";
353365
}
354366

367+
typedef CUipcMemHandle cu_ipc_data_t;
368+
369+
static umf_result_t cu_memory_provider_get_ipc_handle_size(void *provider,
370+
size_t *size) {
371+
if (provider == NULL || size == NULL) {
372+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
373+
}
374+
375+
*size = sizeof(cu_ipc_data_t);
376+
return UMF_RESULT_SUCCESS;
377+
}
378+
379+
static umf_result_t cu_memory_provider_get_ipc_handle(void *provider,
380+
const void *ptr,
381+
size_t size,
382+
void *providerIpcData) {
383+
(void)size;
384+
385+
if (provider == NULL || ptr == NULL || providerIpcData == NULL) {
386+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
387+
}
388+
389+
CUresult cu_result;
390+
cu_ipc_data_t *cu_ipc_data = (cu_ipc_data_t *)providerIpcData;
391+
392+
cu_result = g_cu_ops.cuIpcGetMemHandle(cu_ipc_data, (CUdeviceptr)ptr);
393+
if (cu_result != CUDA_SUCCESS) {
394+
LOG_ERR("cuIpcGetMemHandle() failed.");
395+
return cu2umf_result(cu_result);
396+
}
397+
398+
return UMF_RESULT_SUCCESS;
399+
}
400+
401+
static umf_result_t cu_memory_provider_put_ipc_handle(void *provider,
402+
void *providerIpcData) {
403+
if (provider == NULL || providerIpcData == NULL) {
404+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
405+
}
406+
407+
return UMF_RESULT_SUCCESS;
408+
}
409+
410+
static umf_result_t cu_memory_provider_open_ipc_handle(void *provider,
411+
void *providerIpcData,
412+
void **ptr) {
413+
if (provider == NULL || ptr == NULL || providerIpcData == NULL) {
414+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
415+
}
416+
417+
cu_memory_provider_t *cu_provider = (cu_memory_provider_t *)provider;
418+
419+
CUresult cu_result;
420+
cu_ipc_data_t *cu_ipc_data = (cu_ipc_data_t *)providerIpcData;
421+
422+
// Remember current context and set the one from the provider
423+
CUcontext restore_ctx = NULL;
424+
umf_result_t umf_result = set_context(cu_provider->context, &restore_ctx);
425+
if (umf_result != UMF_RESULT_SUCCESS) {
426+
return umf_result;
427+
}
428+
429+
cu_result = g_cu_ops.cuIpcOpenMemHandle((CUdeviceptr *)ptr, *cu_ipc_data,
430+
CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS);
431+
432+
if (cu_result != CUDA_SUCCESS) {
433+
LOG_ERR("cuIpcOpenMemHandle() failed.");
434+
}
435+
436+
set_context(restore_ctx, &restore_ctx);
437+
438+
return cu2umf_result(cu_result);
439+
}
440+
441+
static umf_result_t
442+
cu_memory_provider_close_ipc_handle(void *provider, void *ptr, size_t size) {
443+
(void)size;
444+
445+
if (provider == NULL || ptr == NULL) {
446+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
447+
}
448+
449+
CUresult cu_result;
450+
451+
cu_result = g_cu_ops.cuIpcCloseMemHandle((CUdeviceptr)ptr);
452+
if (cu_result != CUDA_SUCCESS) {
453+
LOG_ERR("cuIpcCloseMemHandle() failed.");
454+
return cu2umf_result(cu_result);
455+
}
456+
457+
return UMF_RESULT_SUCCESS;
458+
}
459+
355460
static struct umf_memory_provider_ops_t UMF_CUDA_MEMORY_PROVIDER_OPS = {
356461
.version = UMF_VERSION_CURRENT,
357462
.initialize = cu_memory_provider_initialize,
@@ -368,12 +473,12 @@ static struct umf_memory_provider_ops_t UMF_CUDA_MEMORY_PROVIDER_OPS = {
368473
.ext.purge_force = cu_memory_provider_purge_force,
369474
.ext.allocation_merge = cu_memory_provider_allocation_merge,
370475
.ext.allocation_split = cu_memory_provider_allocation_split,
476+
*/
371477
.ipc.get_ipc_handle_size = cu_memory_provider_get_ipc_handle_size,
372478
.ipc.get_ipc_handle = cu_memory_provider_get_ipc_handle,
373479
.ipc.put_ipc_handle = cu_memory_provider_put_ipc_handle,
374480
.ipc.open_ipc_handle = cu_memory_provider_open_ipc_handle,
375481
.ipc.close_ipc_handle = cu_memory_provider_close_ipc_handle,
376-
*/
377482
};
378483

379484
umf_memory_provider_ops_t *umfCUDAMemoryProviderOps(void) {

0 commit comments

Comments
 (0)