1212#include <umf.h>
1313#include <umf/providers/provider_cuda.h>
1414
15+ #include "provider_cuda_internal.h"
16+ #include "utils_load_library.h"
1517#include "utils_log.h"
1618
// Handle to the CUDA shared library held by this file; NULL while no
// handle is stored. It is set by init_cu_global_state() on success.
static void *cu_lib_handle = NULL;

// Closes the stored CUDA library handle, if there is one, and resets it
// to NULL so that a repeated call does nothing.
void fini_cu_global_state(void) {
    if (cu_lib_handle == NULL) {
        // nothing was opened (or it has already been closed)
        return;
    }

    utils_close_library(cu_lib_handle);
    cu_lib_handle = NULL;
}
27+
1728#if defined(UMF_NO_CUDA_PROVIDER )
1829
1930umf_result_t umfCUDAMemoryProviderParamsCreate (
@@ -80,7 +91,6 @@ umf_memory_provider_ops_t *umfCUDAMemoryProviderOps(void) {
8091#include "utils_assert.h"
8192#include "utils_common.h"
8293#include "utils_concurrency.h"
83- #include "utils_load_library.h"
8494#include "utils_log.h"
8595#include "utils_sanitizers.h"
8696
@@ -163,37 +173,45 @@ static void init_cu_global_state(void) {
163173#else
164174 const char * lib_name = "libcuda.so" ;
165175#endif
166- // check if CUDA shared library is already loaded
167- // we pass 0 as a handle to search the global symbol table
176+ // The CUDA shared library should be already loaded by the user
177+ // of the CUDA provider. UMF just wants to re-use it
178+ // and increase the reference count of the CUDA shared library.
179+ void * lib_handle =
180+ utils_open_library (lib_name , UMF_UTIL_OPEN_LIBRARY_NO_LOAD );
181+ if (!lib_handle ) {
182+ LOG_ERR ("Failed to open CUDA shared library" );
183+ Init_cu_global_state_failed = true;
184+ return ;
185+ }
168186
169187 // NOTE: some symbols defined in the lib have _vX postfixes - it is
170188 // important to load the proper version of functions
171- * (void * * )& g_cu_ops .cuMemGetAllocationGranularity =
172- utils_get_symbol_addr ( 0 , "cuMemGetAllocationGranularity" , lib_name );
189+ * (void * * )& g_cu_ops .cuMemGetAllocationGranularity = utils_get_symbol_addr (
190+ lib_handle , "cuMemGetAllocationGranularity" , lib_name );
173191 * (void * * )& g_cu_ops .cuMemAlloc =
174- utils_get_symbol_addr (0 , "cuMemAlloc_v2" , lib_name );
192+ utils_get_symbol_addr (lib_handle , "cuMemAlloc_v2" , lib_name );
175193 * (void * * )& g_cu_ops .cuMemAllocHost =
176- utils_get_symbol_addr (0 , "cuMemAllocHost_v2" , lib_name );
194+ utils_get_symbol_addr (lib_handle , "cuMemAllocHost_v2" , lib_name );
177195 * (void * * )& g_cu_ops .cuMemAllocManaged =
178- utils_get_symbol_addr (0 , "cuMemAllocManaged" , lib_name );
196+ utils_get_symbol_addr (lib_handle , "cuMemAllocManaged" , lib_name );
179197 * (void * * )& g_cu_ops .cuMemFree =
180- utils_get_symbol_addr (0 , "cuMemFree_v2" , lib_name );
198+ utils_get_symbol_addr (lib_handle , "cuMemFree_v2" , lib_name );
181199 * (void * * )& g_cu_ops .cuMemFreeHost =
182- utils_get_symbol_addr (0 , "cuMemFreeHost" , lib_name );
200+ utils_get_symbol_addr (lib_handle , "cuMemFreeHost" , lib_name );
183201 * (void * * )& g_cu_ops .cuGetErrorName =
184- utils_get_symbol_addr (0 , "cuGetErrorName" , lib_name );
202+ utils_get_symbol_addr (lib_handle , "cuGetErrorName" , lib_name );
185203 * (void * * )& g_cu_ops .cuGetErrorString =
186- utils_get_symbol_addr (0 , "cuGetErrorString" , lib_name );
204+ utils_get_symbol_addr (lib_handle , "cuGetErrorString" , lib_name );
187205 * (void * * )& g_cu_ops .cuCtxGetCurrent =
188- utils_get_symbol_addr (0 , "cuCtxGetCurrent" , lib_name );
206+ utils_get_symbol_addr (lib_handle , "cuCtxGetCurrent" , lib_name );
189207 * (void * * )& g_cu_ops .cuCtxSetCurrent =
190- utils_get_symbol_addr (0 , "cuCtxSetCurrent" , lib_name );
208+ utils_get_symbol_addr (lib_handle , "cuCtxSetCurrent" , lib_name );
191209 * (void * * )& g_cu_ops .cuIpcGetMemHandle =
192- utils_get_symbol_addr (0 , "cuIpcGetMemHandle" , lib_name );
210+ utils_get_symbol_addr (lib_handle , "cuIpcGetMemHandle" , lib_name );
193211 * (void * * )& g_cu_ops .cuIpcOpenMemHandle =
194- utils_get_symbol_addr (0 , "cuIpcOpenMemHandle_v2" , lib_name );
212+ utils_get_symbol_addr (lib_handle , "cuIpcOpenMemHandle_v2" , lib_name );
195213 * (void * * )& g_cu_ops .cuIpcCloseMemHandle =
196- utils_get_symbol_addr (0 , "cuIpcCloseMemHandle" , lib_name );
214+ utils_get_symbol_addr (lib_handle , "cuIpcCloseMemHandle" , lib_name );
197215
198216 if (!g_cu_ops .cuMemGetAllocationGranularity || !g_cu_ops .cuMemAlloc ||
199217 !g_cu_ops .cuMemAllocHost || !g_cu_ops .cuMemAllocManaged ||
@@ -204,7 +222,10 @@ static void init_cu_global_state(void) {
204222 !g_cu_ops .cuIpcCloseMemHandle ) {
205223 LOG_ERR ("Required CUDA symbols not found." );
206224 Init_cu_global_state_failed = true;
225+ utils_close_library (lib_handle );
226+ return ;
207227 }
228+ cu_lib_handle = lib_handle ;
208229}
209230
210231umf_result_t umfCUDAMemoryProviderParamsCreate (
@@ -297,7 +318,7 @@ static umf_result_t cu_memory_provider_initialize(void *params,
297318 utils_init_once (& cu_is_initialized , init_cu_global_state );
298319 if (Init_cu_global_state_failed ) {
299320 LOG_ERR ("Loading CUDA symbols failed" );
300- return UMF_RESULT_ERROR_UNKNOWN ;
321+ return UMF_RESULT_ERROR_DEPENDENCY_UNAVAILABLE ;
301322 }
302323
303324 cu_memory_provider_t * cu_provider =
0 commit comments