@@ -139,6 +139,7 @@ typedef struct cu_ops_t {
139139 CUresult (* cuGetErrorName )(CUresult error , const char * * pStr );
140140 CUresult (* cuGetErrorString )(CUresult error , const char * * pStr );
141141 CUresult (* cuCtxGetCurrent )(CUcontext * pctx );
142+ CUresult (* cuCtxGetDevice )(CUdevice * device );
142143 CUresult (* cuCtxSetCurrent )(CUcontext ctx );
143144 CUresult (* cuIpcGetMemHandle )(CUipcMemHandle * pHandle , CUdeviceptr dptr );
144145 CUresult (* cuIpcOpenMemHandle )(CUdeviceptr * pdptr , CUipcMemHandle handle ,
@@ -224,6 +225,8 @@ static void init_cu_global_state(void) {
224225 utils_get_symbol_addr (lib_handle , "cuGetErrorString" , lib_name );
225226 * (void * * )& g_cu_ops .cuCtxGetCurrent =
226227 utils_get_symbol_addr (lib_handle , "cuCtxGetCurrent" , lib_name );
228+ * (void * * )& g_cu_ops .cuCtxGetDevice =
229+ utils_get_symbol_addr (lib_handle , "cuCtxGetDevice" , lib_name );
227230 * (void * * )& g_cu_ops .cuCtxSetCurrent =
228231 utils_get_symbol_addr (lib_handle , "cuCtxSetCurrent" , lib_name );
229232 * (void * * )& g_cu_ops .cuIpcGetMemHandle =
@@ -237,9 +240,9 @@ static void init_cu_global_state(void) {
237240 !g_cu_ops .cuMemHostAlloc || !g_cu_ops .cuMemAllocManaged ||
238241 !g_cu_ops .cuMemFree || !g_cu_ops .cuMemFreeHost ||
239242 !g_cu_ops .cuGetErrorName || !g_cu_ops .cuGetErrorString ||
240- !g_cu_ops .cuCtxGetCurrent || !g_cu_ops .cuCtxSetCurrent ||
241- !g_cu_ops .cuIpcGetMemHandle || !g_cu_ops .cuIpcOpenMemHandle ||
242- !g_cu_ops .cuIpcCloseMemHandle ) {
243+ !g_cu_ops .cuCtxGetCurrent || !g_cu_ops .cuCtxGetDevice ||
244+ !g_cu_ops .cuCtxSetCurrent || !g_cu_ops .cuIpcGetMemHandle ||
245+ !g_cu_ops .cuIpcOpenMemHandle || ! g_cu_ops . cuIpcCloseMemHandle ) {
243246 LOG_FATAL ("Required CUDA symbols not found." );
244247 Init_cu_global_state_failed = true;
245248 utils_close_library (lib_handle );
@@ -263,8 +266,29 @@ umf_result_t umfCUDAMemoryProviderParamsCreate(
263266 return UMF_RESULT_ERROR_OUT_OF_HOST_MEMORY ;
264267 }
265268
266- params_data -> cuda_context_handle = NULL ;
267- params_data -> cuda_device_handle = -1 ;
269+ utils_init_once (& cu_is_initialized , init_cu_global_state );
270+ if (Init_cu_global_state_failed ) {
271+ LOG_FATAL ("Loading CUDA symbols failed" );
272+ return UMF_RESULT_ERROR_DEPENDENCY_UNAVAILABLE ;
273+ }
274+
275+ // initialize context and device to the current ones
276+ CUcontext current_ctx = NULL ;
277+ CUresult cu_result = g_cu_ops .cuCtxGetCurrent (& current_ctx );
278+ if (cu_result == CUDA_SUCCESS ) {
279+ params_data -> cuda_context_handle = current_ctx ;
280+ } else {
281+ params_data -> cuda_context_handle = NULL ;
282+ }
283+
284+ CUdevice current_device = -1 ;
285+ cu_result = g_cu_ops .cuCtxGetDevice (& current_device );
286+ if (cu_result == CUDA_SUCCESS ) {
287+ params_data -> cuda_device_handle = current_device ;
288+ } else {
289+ params_data -> cuda_device_handle = -1 ;
290+ }
291+
268292 params_data -> memory_type = UMF_MEMORY_TYPE_UNKNOWN ;
269293 params_data -> alloc_flags = 0 ;
270294
@@ -345,6 +369,12 @@ static umf_result_t cu_memory_provider_initialize(void *params,
345369 }
346370
347371 if (cu_params -> cuda_context_handle == NULL ) {
372+ LOG_ERR ("Invalid context handle" );
373+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
374+ }
375+
376+ if (cu_params -> cuda_device_handle < 0 ) {
377+ LOG_ERR ("Invalid device handle" );
348378 return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
349379 }
350380
0 commit comments