Skip to content

Commit f2b69a0

Browse files
pytorchbot, eee4017, nWEIdia
authored
[CUDA] Use runtime driver API for cuStreamWriteValue32 (pytorch#158585)
[CUDA] Use runtime driver API for cuStreamWriteValue32 (pytorch#158295) Reopen pytorch#156097 Fixes pytorch#154073 Reference: NVIDIA/Fuser#4197 See PR pytorch#156097 and pytorch#154097 Pull Request resolved: pytorch#158295 Approved by: https://github.com/Skylion007, https://github.com/ngimel, https://github.com/eqy, https://github.com/huydhn (cherry picked from commit a9f902a) Co-authored-by: Frank Lin <[email protected]> Co-authored-by: Wei Wang <[email protected]>
1 parent 8dd8510 commit f2b69a0

File tree

3 files changed

+81
-38
lines changed

3 files changed

+81
-38
lines changed

c10/cuda/driver_api.cpp

Lines changed: 43 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,35 @@
11
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
2+
#include <c10/cuda/CUDAException.h>
23
#include <c10/cuda/driver_api.h>
34
#include <c10/util/CallOnce.h>
45
#include <c10/util/Exception.h>
6+
#include <c10/util/Logging.h>
7+
#include <cuda_runtime.h>
58
#include <dlfcn.h>
69

710
namespace c10::cuda {
811

912
namespace {
1013

14+
void* get_symbol(const char* name, int version);
15+
1116
DriverAPI create_driver_api() {
12-
void* handle_0 = dlopen("libcuda.so.1", RTLD_LAZY | RTLD_NOLOAD);
13-
TORCH_CHECK(handle_0, "Can't open libcuda.so.1: ", dlerror());
1417
void* handle_1 = DriverAPI::get_nvml_handle();
1518
DriverAPI r{};
1619

17-
#define LOOKUP_LIBCUDA_ENTRY(name) \
18-
r.name##_ = ((decltype(&name))dlsym(handle_0, #name)); \
19-
TORCH_INTERNAL_ASSERT(r.name##_, "Can't find ", #name, ": ", dlerror())
20-
C10_LIBCUDA_DRIVER_API(LOOKUP_LIBCUDA_ENTRY)
21-
#undef LOOKUP_LIBCUDA_ENTRY
20+
#define LOOKUP_LIBCUDA_ENTRY_WITH_VERSION_REQUIRED(name, version) \
21+
r.name##_ = reinterpret_cast<decltype(&name)>(get_symbol(#name, version)); \
22+
TORCH_INTERNAL_ASSERT(r.name##_, "Can't find ", #name);
23+
C10_LIBCUDA_DRIVER_API_REQUIRED(LOOKUP_LIBCUDA_ENTRY_WITH_VERSION_REQUIRED)
24+
#undef LOOKUP_LIBCUDA_ENTRY_WITH_VERSION_REQUIRED
2225

23-
#define LOOKUP_LIBCUDA_ENTRY(name) \
24-
r.name##_ = ((decltype(&name))dlsym(handle_0, #name)); \
25-
dlerror();
26-
C10_LIBCUDA_DRIVER_API_12030(LOOKUP_LIBCUDA_ENTRY)
27-
#undef LOOKUP_LIBCUDA_ENTRY
26+
// Users running drivers between 12.0 and 12.3 will not have these symbols,
27+
// they will be resolved to nullptr, but we guard their usage at runtime
28+
// to ensure safe fallback behavior.
29+
#define LOOKUP_LIBCUDA_ENTRY_WITH_VERSION_OPTIONAL(name, version) \
30+
r.name##_ = reinterpret_cast<decltype(&name)>(get_symbol(#name, version));
31+
C10_LIBCUDA_DRIVER_API_OPTIONAL(LOOKUP_LIBCUDA_ENTRY_WITH_VERSION_OPTIONAL)
32+
#undef LOOKUP_LIBCUDA_ENTRY_WITH_VERSION_OPTIONAL
2833

2934
if (handle_1) {
3035
#define LOOKUP_NVML_ENTRY(name) \
@@ -35,6 +40,32 @@ DriverAPI create_driver_api() {
3540
}
3641
return r;
3742
}
43+
44+
void* get_symbol(const char* name, int version) {
45+
void* out = nullptr;
46+
cudaDriverEntryPointQueryResult qres{};
47+
48+
// CUDA 12.5+ supports version-based lookup
49+
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12050)
50+
if (auto st = cudaGetDriverEntryPointByVersion(
51+
name, &out, version, cudaEnableDefault, &qres);
52+
st == cudaSuccess && qres == cudaDriverEntryPointSuccess && out) {
53+
return out;
54+
}
55+
#endif
56+
57+
// This falls back to the old API to try resolving the symbol again.
58+
if (auto st = cudaGetDriverEntryPoint(name, &out, cudaEnableDefault, &qres);
59+
st == cudaSuccess && qres == cudaDriverEntryPointSuccess && out) {
60+
return out;
61+
}
62+
63+
// If the symbol cannot be resolved, report and return nullptr;
64+
// the caller is responsible for checking the pointer.
65+
LOG(INFO) << "Failed to resolve symbol " << name;
66+
return nullptr;
67+
}
68+
3869
} // namespace
3970

4071
void* DriverAPI::get_nvml_handle() {

c10/cuda/driver_api.h

Lines changed: 38 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -20,29 +20,42 @@
2020
} \
2121
} while (0)
2222

23-
#define C10_LIBCUDA_DRIVER_API(_) \
24-
_(cuDeviceGetAttribute) \
25-
_(cuMemAddressReserve) \
26-
_(cuMemRelease) \
27-
_(cuMemMap) \
28-
_(cuMemAddressFree) \
29-
_(cuMemSetAccess) \
30-
_(cuMemUnmap) \
31-
_(cuMemCreate) \
32-
_(cuMemGetAllocationGranularity) \
33-
_(cuMemExportToShareableHandle) \
34-
_(cuMemImportFromShareableHandle) \
35-
_(cuMemsetD32Async) \
36-
_(cuStreamWriteValue32) \
37-
_(cuGetErrorString)
23+
// The integer in the second column specifies the requested CUDA Driver API
24+
// version. The dynamic loader will accept a driver with a newer version, and it
// ensures that the requested symbol was already available in the specified
// driver version.
27+
28+
// Keep these requested versions as low as possible to maximize compatibility
29+
// across different driver versions.
30+
31+
// Why do we pin to an older version instead of using the latest?
32+
// If a user installs a newer driver, blindly resolving the symbol may bind to a
33+
// newer version of the function with different behavior, potentially breaking
34+
// PyTorch.
35+
36+
#define C10_LIBCUDA_DRIVER_API_REQUIRED(_) \
37+
_(cuDeviceGetAttribute, 12000) \
38+
_(cuMemAddressReserve, 12000) \
39+
_(cuMemRelease, 12000) \
40+
_(cuMemMap, 12000) \
41+
_(cuMemAddressFree, 12000) \
42+
_(cuMemSetAccess, 12000) \
43+
_(cuMemUnmap, 12000) \
44+
_(cuMemCreate, 12000) \
45+
_(cuMemGetAllocationGranularity, 12000) \
46+
_(cuMemExportToShareableHandle, 12000) \
47+
_(cuMemImportFromShareableHandle, 12000) \
48+
_(cuMemsetD32Async, 12000) \
49+
_(cuStreamWriteValue32, 12000) \
50+
_(cuGetErrorString, 12000)
3851

3952
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12030)
40-
#define C10_LIBCUDA_DRIVER_API_12030(_) \
41-
_(cuMulticastAddDevice) \
42-
_(cuMulticastBindMem) \
43-
_(cuMulticastCreate)
53+
#define C10_LIBCUDA_DRIVER_API_OPTIONAL(_) \
54+
_(cuMulticastAddDevice, 12030) \
55+
_(cuMulticastBindMem, 12030) \
56+
_(cuMulticastCreate, 12030)
4457
#else
45-
#define C10_LIBCUDA_DRIVER_API_12030(_)
58+
#define C10_LIBCUDA_DRIVER_API_OPTIONAL(_)
4659
#endif
4760

4861
#define C10_NVML_DRIVER_API(_) \
@@ -56,11 +69,14 @@
5669
namespace c10::cuda {
5770

5871
struct DriverAPI {
72+
#define CREATE_MEMBER_VERSIONED(name, version) decltype(&name) name##_;
5973
#define CREATE_MEMBER(name) decltype(&name) name##_;
60-
C10_LIBCUDA_DRIVER_API(CREATE_MEMBER)
61-
C10_LIBCUDA_DRIVER_API_12030(CREATE_MEMBER)
74+
C10_LIBCUDA_DRIVER_API_REQUIRED(CREATE_MEMBER_VERSIONED)
75+
C10_LIBCUDA_DRIVER_API_OPTIONAL(CREATE_MEMBER_VERSIONED)
6276
C10_NVML_DRIVER_API(CREATE_MEMBER)
77+
#undef CREATE_MEMBER_VERSIONED
6378
#undef CREATE_MEMBER
79+
6480
static DriverAPI* get();
6581
static void* get_nvml_handle();
6682
};

test/distributed/test_symmetric_memory.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1078,10 +1078,6 @@ class SymmMemSingleProcTest(TestCase):
10781078
not TEST_WITH_ROCM and _get_torch_cuda_version() < (12, 0),
10791079
"stream_write_value32 currently only supports cuda version>=12.0",
10801080
)
1081-
@skipIf(
1082-
_get_torch_cuda_version() >= (12, 6),
1083-
"https://github.com/pytorch/pytorch/issues/154073",
1084-
)
10851081
@runOnRocmArch(MI300_ARCH)
10861082
def test_stream_write_value32(self):
10871083
tensor = torch.zeros(4, dtype=torch.uint32, device="cuda")

0 commit comments

Comments
 (0)