Skip to content

Commit 49e6359

Browse files
authored
[libcu++] Dynamically load CUDA library instead of using the runtime (#6899)
* Dynamically load CUDA library instead of using the runtime * Switch back to windows.h * Review feedback * More review feedback
1 parent db78d42 commit 49e6359

File tree

1 file changed

+34
-13
lines changed

1 file changed

+34
-13
lines changed

libcudacxx/include/cuda/__driver/driver_api.h

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@
2828
# include <cuda/std/__internal/namespaces.h>
2929
# include <cuda/std/__type_traits/always_false.h>
3030
# include <cuda/std/__type_traits/is_same.h>
31+
# if _CCCL_OS(WINDOWS)
32+
# include <windows.h>
33+
# else
34+
# include <dlfcn.h>
35+
# endif
3136

3237
# include <cuda.h>
3338

@@ -47,21 +52,37 @@ _CCCL_BEGIN_NAMESPACE_CUDA_DRIVER
4752
_CCCL_SUPPRESS_DEPRECATED_PUSH
4853

4954
//! @brief Gets the cuGetProcAddress function pointer.
50-
[[nodiscard]] _CCCL_HOST_API inline auto __getProcAddressFn() -> decltype(cuGetProcAddress)*
55+
[[nodiscard]] _CCCL_PUBLIC_HOST_API inline auto __getProcAddressFn() -> decltype(cuGetProcAddress)*
5156
{
52-
// TODO switch to dlopen of libcuda.so instead of the below
53-
void* __fn;
54-
::cudaDriverEntryPointQueryResult __result;
55-
# if _CCCL_CTK_AT_LEAST(13, 0)
56-
::cudaError_t __status =
57-
::cudaGetDriverEntryPointByVersion("cuGetProcAddress", &__fn, 13000, ::cudaEnableDefault, &__result);
58-
# else
59-
::cudaError_t __status = ::cudaGetDriverEntryPoint("cuGetProcAddress", &__fn, ::cudaEnableDefault, &__result);
60-
# endif
61-
if (__status != ::cudaSuccess || __result != ::cudaDriverEntryPointSuccess)
57+
const char* __fn_name = "cuGetProcAddress_v2";
58+
# if _CCCL_OS(WINDOWS)
59+
static auto __driver_library = ::LoadLibraryExA("nvcuda.dll", nullptr, LOAD_LIBRARY_SEARCH_SYSTEM32);
60+
if (__driver_library == nullptr)
61+
{
62+
::cuda::__throw_cuda_error(::cudaErrorUnknown, "Failed to load nvcuda.dll");
63+
}
64+
static void* __fn = ::GetProcAddress(__driver_library, __fn_name);
65+
if (__fn == nullptr)
66+
{
67+
::cuda::__throw_cuda_error(::cudaErrorInitializationError, "Failed to get cuGetProcAddress from nvcuda.dll");
68+
}
69+
# else // ^^^ _CCCL_OS(WINDOWS) ^^^ / vvv !_CCCL_OS(WINDOWS) vvv
70+
# if _CCCL_OS(ANDROID)
71+
const char* __driver_library_name = "libcuda.so";
72+
# else // ^^^ _CCCL_OS(ANDROID) ^^^ / vvv !_CCCL_OS(ANDROID) vvv
73+
const char* __driver_library_name = "libcuda.so.1";
74+
# endif // ^^^ !_CCCL_OS(ANDROID) ^^^
75+
static void* __driver_library = ::dlopen(__driver_library_name, RTLD_NOW);
76+
if (__driver_library == nullptr)
77+
{
78+
::cuda::__throw_cuda_error(::cudaErrorUnknown, "Failed to load libcuda.so.1");
79+
}
80+
static void* __fn = ::dlsym(__driver_library, __fn_name);
81+
if (__fn == nullptr)
6282
{
63-
::cuda::__throw_cuda_error(::cudaErrorUnknown, "Failed to get cuGetProcAddress");
83+
::cuda::__throw_cuda_error(::cudaErrorInitializationError, "Failed to get cuGetProcAddress from libcuda.so.1");
6484
}
85+
# endif // ^^^ !_CCCL_OS(WINDOWS) ^^^
6586
return reinterpret_cast<decltype(cuGetProcAddress)*>(__fn);
6687
}
6788

@@ -152,7 +173,7 @@ _CCCL_HOST_API inline void __call_driver_fn(Fn __fn, const char* __err_msg, Args
152173
//! @return The address of the symbol.
153174
//!
154175
//! @throws @c cuda::cuda_error if the symbol cannot be obtained or the CUDA driver failed to initialize.
155-
[[nodiscard]] _CCCL_HOST_API inline void*
176+
[[nodiscard]] _CCCL_PUBLIC_HOST_API inline void*
156177
__get_driver_entry_point(const char* __name, [[maybe_unused]] int __major = 12, [[maybe_unused]] int __minor = 0)
157178
{
158179
// Get cuGetProcAddress function and call cuInit(0) only on the first call

0 commit comments

Comments
 (0)