Skip to content

Commit 2436177

Browse files
committed
Fix CUDA provider to use proper context
1 parent 27d4fba commit 2436177

File tree

1 file changed

+45
-1
lines changed

1 file changed

+45
-1
lines changed

src/provider/provider_cuda.c

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ typedef struct cu_ops_t {
5151

5252
CUresult (*cuGetErrorName)(CUresult error, const char **pStr);
5353
CUresult (*cuGetErrorString)(CUresult error, const char **pStr);
54+
CUresult (*cuCtxGetCurrent)(CUcontext* pctx);
55+
CUresult (*cuCtxSetCurrent)(CUcontext ctx);
5456
CUresult (*cuIpcGetMemHandle)(CUipcMemHandle *pHandle, CUdeviceptr dptr);
5557
CUresult (*cuIpcOpenMemHandle)(CUdeviceptr *pdptr, CUipcMemHandle handle,
5658
unsigned int Flags);
@@ -121,6 +123,10 @@ static void init_cu_global_state(void) {
121123
utils_get_symbol_addr(0, "cuGetErrorName", lib_name);
122124
*(void **)&g_cu_ops.cuGetErrorString =
123125
utils_get_symbol_addr(0, "cuGetErrorString", lib_name);
126+
*(void **)&g_cu_ops.cuCtxGetCurrent =
127+
utils_get_symbol_addr(0, "cuCtxGetCurrent", lib_name);
128+
*(void **)&g_cu_ops.cuCtxSetCurrent =
129+
utils_get_symbol_addr(0, "cuCtxSetCurrent", lib_name);
124130
*(void **)&g_cu_ops.cuIpcGetMemHandle =
125131
utils_get_symbol_addr(0, "cuIpcGetMemHandle", lib_name);
126132
*(void **)&g_cu_ops.cuIpcOpenMemHandle =
@@ -132,6 +138,7 @@ static void init_cu_global_state(void) {
132138
!g_cu_ops.cuMemAllocHost || !g_cu_ops.cuMemAllocManaged ||
133139
!g_cu_ops.cuMemFree || !g_cu_ops.cuMemFreeHost ||
134140
!g_cu_ops.cuGetErrorName || !g_cu_ops.cuGetErrorString ||
141+
!g_cu_ops.cuCtxGetCurrent || !g_cu_ops.cuCtxSetCurrent ||
135142
!g_cu_ops.cuIpcGetMemHandle || !g_cu_ops.cuIpcOpenMemHandle ||
136143
!g_cu_ops.cuIpcCloseMemHandle) {
137144
LOG_ERR("Required CUDA symbols not found.");
@@ -202,6 +209,31 @@ static void cu_memory_provider_finalize(void *provider) {
202209
umf_ba_global_free(provider);
203210
}
204211

212+
/*
213+
* This function is used by the CUDA provider to make sure that
214+
* the required context is set. If the current context is
215+
* not the required one, it will be saved in restore_ctx.
216+
*/
217+
static inline umf_result_t set_context(CUcontext required_ctx,
218+
CUcontext *restore_ctx) {
219+
CUcontext current_ctx = NULL;
220+
CUresult cu_result = g_cu_ops.cuCtxGetCurrent(&current_ctx);
221+
if (cu_result != CUDA_SUCCESS) {
222+
LOG_ERR("cuCtxGetCurrent() failed.");
223+
return cu2umf_result(cu_result);
224+
}
225+
*restore_ctx = current_ctx;
226+
if (current_ctx != required_ctx) {
227+
cu_result = g_cu_ops.cuCtxSetCurrent(required_ctx);
228+
if (cu_result != CUDA_SUCCESS) {
229+
LOG_ERR("cuCtxSetCurrent() failed.");
230+
return cu2umf_result(cu_result);
231+
}
232+
}
233+
234+
return UMF_RESULT_SUCCESS;
235+
}
236+
205237
static umf_result_t cu_memory_provider_alloc(void *provider, size_t size,
206238
size_t alignment,
207239
void **resultPtr) {
@@ -217,6 +249,13 @@ static umf_result_t cu_memory_provider_alloc(void *provider, size_t size,
217249
return UMF_RESULT_ERROR_NOT_SUPPORTED;
218250
}
219251

252+
// Remember current context and set the one from the provider
253+
CUcontext restore_ctx = NULL;
254+
umf_result_t umf_result = set_context(cu_provider->context, &restore_ctx);
255+
if (umf_result != UMF_RESULT_SUCCESS) {
256+
return umf_result;
257+
}
258+
220259
CUresult cu_result = CUDA_SUCCESS;
221260
switch (cu_provider->memory_type) {
222261
case UMF_MEMORY_TYPE_HOST: {
@@ -236,16 +275,21 @@ static umf_result_t cu_memory_provider_alloc(void *provider, size_t size,
236275
// this shouldn't happen as we check the memory_type settings during
237276
// the initialization
238277
LOG_ERR("unsupported USM memory type");
239-
return UMF_RESULT_ERROR_UNKNOWN;
278+
assert(false);
240279
}
241280

242281
// check the alignment
243282
if (alignment > 0 && ((uintptr_t)(*resultPtr) % alignment) != 0) {
244283
cu_memory_provider_free(provider, *resultPtr, size);
245284
LOG_ERR("unsupported alignment size");
285+
set_context(restore_ctx, &restore_ctx);
246286
return UMF_RESULT_ERROR_INVALID_ALIGNMENT;
247287
}
248288

289+
umf_result = set_context(restore_ctx, &restore_ctx);
290+
if (umf_result != UMF_RESULT_SUCCESS) {
291+
return umf_result;
292+
}
249293
return cu2umf_result(cu_result);
250294
}
251295

0 commit comments

Comments
 (0)