2828# include < cuda/std/__internal/namespaces.h>
2929# include < cuda/std/__type_traits/always_false.h>
3030# include < cuda/std/__type_traits/is_same.h>
31+ # if _CCCL_OS(WINDOWS)
32+ # include < windows.h>
33+ # else
34+ # include < dlfcn.h>
35+ # endif
3136
3237# include < cuda.h>
3338
@@ -47,21 +52,37 @@ _CCCL_BEGIN_NAMESPACE_CUDA_DRIVER
4752_CCCL_SUPPRESS_DEPRECATED_PUSH
4853
4954// ! @brief Gets the cuGetProcAddress function pointer.
50- [[nodiscard]] _CCCL_HOST_API inline auto __getProcAddressFn() -> decltype(cuGetProcAddress)*
55+ [[nodiscard]] _CCCL_PUBLIC_HOST_API inline auto __getProcAddressFn() -> decltype(cuGetProcAddress)*
5156{
52- // TODO switch to dlopen of libcuda.so instead of the below
53- void * __fn;
54- ::cudaDriverEntryPointQueryResult __result;
55- # if _CCCL_CTK_AT_LEAST(13, 0)
56- ::cudaError_t __status =
57- ::cudaGetDriverEntryPointByVersion (" cuGetProcAddress" , &__fn, 13000 , ::cudaEnableDefault, &__result);
58- # else
59- ::cudaError_t __status = ::cudaGetDriverEntryPoint (" cuGetProcAddress" , &__fn, ::cudaEnableDefault, &__result);
60- # endif
61- if (__status != ::cudaSuccess || __result != ::cudaDriverEntryPointSuccess)
57+ const char * __fn_name = " cuGetProcAddress_v2" ;
58+ # if _CCCL_OS(WINDOWS)
59+ static auto __driver_library = ::LoadLibraryExA (" nvcuda.dll" , nullptr , LOAD_LIBRARY_SEARCH_SYSTEM32);
60+ if (__driver_library == nullptr )
61+ {
62+ ::cuda::__throw_cuda_error (::cudaErrorUnknown, " Failed to load nvcuda.dll" );
63+ }
64+ static void * __fn = ::GetProcAddress (__driver_library, __fn_name);
65+ if (__fn == nullptr )
66+ {
67+ ::cuda::__throw_cuda_error (::cudaErrorInitializationError, " Failed to get cuGetProcAddress from nvcuda.dll" );
68+ }
69+ # else // ^^^ _CCCL_OS(WINDOWS) ^^^ / vvv !_CCCL_OS(WINDOWS) vvv
70+ # if _CCCL_OS(ANDROID)
71+ const char * __driver_library_name = " libcuda.so" ;
72+ # else // ^^^ _CCCL_OS(ANDROID) ^^^ / vvv !_CCCL_OS(ANDROID) vvv
73+ const char * __driver_library_name = " libcuda.so.1" ;
74+ # endif // ^^^ !_CCCL_OS(ANDROID) ^^^
75+ static void * __driver_library = ::dlopen (__driver_library_name, RTLD_NOW);
76+ if (__driver_library == nullptr )
77+ {
78+ ::cuda::__throw_cuda_error (::cudaErrorUnknown, " Failed to load libcuda.so.1" );
79+ }
80+ static void * __fn = ::dlsym (__driver_library, __fn_name);
81+ if (__fn == nullptr )
6282 {
63- ::cuda::__throw_cuda_error (::cudaErrorUnknown , " Failed to get cuGetProcAddress" );
83+ ::cuda::__throw_cuda_error (::cudaErrorInitializationError , " Failed to get cuGetProcAddress from libcuda.so.1 " );
6484 }
85+ # endif // ^^^ !_CCCL_OS(WINDOWS) ^^^
6586 return reinterpret_cast <decltype (cuGetProcAddress)*>(__fn);
6687}
6788
@@ -152,7 +173,7 @@ _CCCL_HOST_API inline void __call_driver_fn(Fn __fn, const char* __err_msg, Args
152173// ! @return The address of the symbol.
153174// !
154175// ! @throws @c cuda::cuda_error if the symbol cannot be obtained or the CUDA driver failed to initialize.
155- [[nodiscard]] _CCCL_HOST_API inline void *
176+ [[nodiscard]] _CCCL_PUBLIC_HOST_API inline void *
156177__get_driver_entry_point (const char * __name, [[maybe_unused]] int __major = 12 , [[maybe_unused]] int __minor = 0 )
157178{
158179 // Get cuGetProcAddress function and call cuInit(0) only on the first call
0 commit comments