1
1
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
2
+ #include < c10/cuda/CUDAException.h>
2
3
#include < c10/cuda/driver_api.h>
3
4
#include < c10/util/CallOnce.h>
4
5
#include < c10/util/Exception.h>
6
+ #include < c10/util/Logging.h>
7
+ #include < cuda_runtime.h>
5
8
#include < dlfcn.h>
6
9
7
10
namespace c10 ::cuda {
8
11
9
12
namespace {
10
13
14
+ void * get_symbol (const char * name, int version);
15
+
11
16
DriverAPI create_driver_api () {
12
- void * handle_0 = dlopen (" libcuda.so.1" , RTLD_LAZY | RTLD_NOLOAD);
13
- TORCH_CHECK (handle_0, " Can't open libcuda.so.1: " , dlerror ());
14
17
void * handle_1 = DriverAPI::get_nvml_handle ();
15
18
DriverAPI r{};
16
19
17
- #define LOOKUP_LIBCUDA_ENTRY (name ) \
18
- r.name ##_ = (( decltype (&name)) dlsym (handle_0, #name)); \
19
- TORCH_INTERNAL_ASSERT (r.name ##_, " Can't find " , #name, " : " , dlerror ())
20
- C10_LIBCUDA_DRIVER_API (LOOKUP_LIBCUDA_ENTRY )
21
- #undef LOOKUP_LIBCUDA_ENTRY
20
+ #define LOOKUP_LIBCUDA_ENTRY_WITH_VERSION_REQUIRED (name, version ) \
21
+ r.name ##_ = reinterpret_cast < decltype (&name)>( get_symbol ( #name, version )); \
22
+ TORCH_INTERNAL_ASSERT (r.name ##_, " Can't find " , #name);
23
+ C10_LIBCUDA_DRIVER_API_REQUIRED (LOOKUP_LIBCUDA_ENTRY_WITH_VERSION_REQUIRED )
24
+ #undef LOOKUP_LIBCUDA_ENTRY_WITH_VERSION_REQUIRED
22
25
23
- #define LOOKUP_LIBCUDA_ENTRY (name ) \
24
- r.name ##_ = ((decltype (&name))dlsym (handle_0, #name)); \
25
- dlerror ();
26
- C10_LIBCUDA_DRIVER_API_12030 (LOOKUP_LIBCUDA_ENTRY)
27
- #undef LOOKUP_LIBCUDA_ENTRY
26
+ // Users running drivers between 12.0 and 12.3 will not have these symbols,
27
+ // they would be resolved into nullptr, but we guard their usage at runtime
28
+ // to ensure safe fallback behavior.
29
+ #define LOOKUP_LIBCUDA_ENTRY_WITH_VERSION_OPTIONAL (name, version ) \
30
+ r.name ##_ = reinterpret_cast <decltype (&name)>(get_symbol (#name, version));
31
+ C10_LIBCUDA_DRIVER_API_OPTIONAL (LOOKUP_LIBCUDA_ENTRY_WITH_VERSION_OPTIONAL)
32
+ #undef LOOKUP_LIBCUDA_ENTRY_WITH_VERSION_OPTIONAL
28
33
29
34
if (handle_1) {
30
35
#define LOOKUP_NVML_ENTRY (name ) \
@@ -35,6 +40,32 @@ DriverAPI create_driver_api() {
35
40
}
36
41
return r;
37
42
}
43
+
44
+ void * get_symbol (const char * name, int version) {
45
+ void * out = nullptr ;
46
+ cudaDriverEntryPointQueryResult qres{};
47
+
48
+ // CUDA 12.5+ supports version-based lookup
49
+ #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12050)
50
+ if (auto st = cudaGetDriverEntryPointByVersion (
51
+ name, &out, version, cudaEnableDefault, &qres);
52
+ st == cudaSuccess && qres == cudaDriverEntryPointSuccess && out) {
53
+ return out;
54
+ }
55
+ #endif
56
+
57
+ // This fallback to the old API to try getting the symbol again.
58
+ if (auto st = cudaGetDriverEntryPoint (name, &out, cudaEnableDefault, &qres);
59
+ st == cudaSuccess && qres == cudaDriverEntryPointSuccess && out) {
60
+ return out;
61
+ }
62
+
63
+ // If the symbol cannot be resolved, report and return nullptr;
64
+ // the caller is responsible for checking the pointer.
65
+ LOG (INFO) << " Failed to resolve symbol " << name;
66
+ return nullptr ;
67
+ }
68
+
38
69
} // namespace
39
70
40
71
void * DriverAPI::get_nvml_handle () {
0 commit comments