Commit 40e6f09

t-ivan-gr authored and pytorchmergebot committed
Re-land "Fix thread safety in getCurrentCUDABlasHandle and getCUDABlasLtWorkspace" (pytorch#167722)
Summary:
getCurrentCUDABlasHandle() and getCUDABlasLtWorkspace() use static mutable maps that are not protected from concurrent read-and-write. This leads to crashes.

This diff adds mutexes to synchronize access to the static maps.

Note: this is a re-land of D86316117 / pytorch#167248 (see comments for details).

Test Plan:
Use a GPU OD, run multi-threaded tests (cuda_cublas_handle_pool_test) with TSAN:

```
buck test fbcode//mode/dev-tsan fbcode//caffe2:cuda_cublas_handle_pool_test -- --stress-runs 100
```

https://www.internalfb.com/intern/testinfra/testrun/14355223937501118

TSAN output (before synchronization was added): P2026731804

Differential Revision: D86964261

Pull Request resolved: pytorch#167722
Approved by: https://github.com/malfet
1 parent bfddfde commit 40e6f09

File tree

4 files changed: +164 −19 lines changed
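The race described in the summary reduces to a handful of lines. The following standalone sketch is illustrative only (none of these names appear in the patch): a static `std::map` read by one thread while another thread clears it. `std::map` is not safe for concurrent read-and-write, so a build with `-fsanitize=thread` reports a data race, and the reader can crash while walking a tree the clearing thread is deallocating.

```cpp
#include <map>
#include <thread>

// Hypothetical reduction of the reported bug: an unguarded static map,
// mirroring the shape (but not the names) of the handle-to-workspace maps.
static std::map<int, int>& unguarded_map() {
  static auto& instance = *new std::map<int, int>;
  return instance;
}

int main() {
  unguarded_map()[0] = 42;
  std::thread reader([] {
    for (int i = 0; i < 100000; ++i) {
      (void)unguarded_map().find(0); // concurrent read
    }
  });
  std::thread clearer([] {
    for (int i = 0; i < 100000; ++i) {
      unguarded_map().clear(); // concurrent write: data race
    }
  });
  reader.join();
  clearer.join();
  return 0;
}
```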

aten/src/ATen/cuda/CUDAContextLight.h

Lines changed: 8 additions & 2 deletions

```diff
@@ -3,6 +3,7 @@
 
 #include <cstdint>
 #include <map>
+#include <shared_mutex>
 
 #include <cuda_runtime_api.h>
 #include <cusparse.h>
@@ -88,8 +89,13 @@ TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
 TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();
 
 TORCH_CUDA_CPP_API void clearCublasWorkspaces();
-TORCH_CUDA_CPP_API std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace();
-TORCH_CUDA_CPP_API std::map<std::tuple<void *, void *>, at::DataPtr>& cublaslt_handle_stream_to_workspace();
+struct WorkspaceMapWithMutex {
+  std::map<std::tuple<void*, void*>, at::DataPtr> map;
+  std::shared_mutex mutex;
+};
+
+TORCH_CUDA_CPP_API WorkspaceMapWithMutex& cublas_handle_stream_to_workspace();
+TORCH_CUDA_CPP_API WorkspaceMapWithMutex& cublaslt_handle_stream_to_workspace();
 TORCH_CUDA_CPP_API size_t getChosenWorkspaceSize();
 TORCH_CUDA_CPP_API size_t getCUDABlasLtWorkspaceSize();
 TORCH_CUDA_CPP_API void* getCUDABlasLtWorkspace();
```
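A minimal sketch of how the new struct is meant to be used. The two helper functions below are hypothetical (only the struct shape matches the header), and `int` stands in for `at::DataPtr` so the snippet is self-contained: lookups share the mutex, mutations take it exclusively.

```cpp
#include <map>
#include <mutex>
#include <shared_mutex>
#include <tuple>

// Same shape as the struct added to CUDAContextLight.h, with the value
// type simplified for illustration.
struct WorkspaceMapWithMutex {
  std::map<std::tuple<void*, void*>, int> map; // at::DataPtr in the real code
  std::shared_mutex mutex;
};

// Many threads may look up concurrently under a shared (reader) lock.
bool contains(WorkspaceMapWithMutex& w, const std::tuple<void*, void*>& key) {
  std::shared_lock<std::shared_mutex> lock(w.mutex);
  return w.map.find(key) != w.map.end();
}

// Mutations take a unique (writer) lock, excluding all readers and writers.
void clear(WorkspaceMapWithMutex& w) {
  std::unique_lock<std::shared_mutex> lock(w.mutex);
  w.map.clear();
}
```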

aten/src/ATen/cuda/CublasHandlePool.cpp

Lines changed: 78 additions & 17 deletions

```diff
@@ -99,27 +99,35 @@ void destroyCublasHandle(cublasHandle_t handle) {
 // - Comments of @soumith copied from cuDNN handle pool implementation
 #ifdef NO_CUDNN_DESTROY_HANDLE
 #else
-    cublasDestroy(handle);
+  cublasDestroy(handle);
 #endif
 }
 
 using CuBlasPoolType = DeviceThreadHandlePool<cublasHandle_t, createCublasHandle, destroyCublasHandle>;
 
 } // namespace
 
-std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace() {
-  static auto& instance = *new std::map<std::tuple<void *, void *>, at::DataPtr>;
+WorkspaceMapWithMutex& cublas_handle_stream_to_workspace() {
+  static auto& instance = *new WorkspaceMapWithMutex;
   return instance;
 }
 
-std::map<std::tuple<void *, void *>, at::DataPtr>& cublaslt_handle_stream_to_workspace() {
-  static auto& instance = *new std::map<std::tuple<void *, void *>, at::DataPtr>;
+WorkspaceMapWithMutex& cublaslt_handle_stream_to_workspace() {
+  static auto& instance = *new WorkspaceMapWithMutex;
   return instance;
 }
 
 void clearCublasWorkspaces() {
-  cublas_handle_stream_to_workspace().clear();
-  cublaslt_handle_stream_to_workspace().clear();
+  {
+    auto& workspace = cublas_handle_stream_to_workspace();
+    std::unique_lock<std::shared_mutex> lock(workspace.mutex);
+    workspace.map.clear();
+  }
+  {
+    auto& workspace = cublaslt_handle_stream_to_workspace();
+    std::unique_lock<std::shared_mutex> lock(workspace.mutex);
+    workspace.map.clear();
+  }
 }
 
 size_t parseChosenWorkspaceSize() {
@@ -241,20 +249,45 @@ void* getCUDABlasLtWorkspace() {
     auto stream = c10::cuda::getCurrentCUDAStream();
     cudaStream_t _stream = stream;
     auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
-    auto workspace_it = at::cuda::cublas_handle_stream_to_workspace().find(key);
-    TORCH_INTERNAL_ASSERT(workspace_it != at::cuda::cublas_handle_stream_to_workspace().end());
+    auto& workspace = at::cuda::cublas_handle_stream_to_workspace();
+    std::shared_lock<std::shared_mutex> lock(workspace.mutex);
+    auto workspace_it = workspace.map.find(key);
+    TORCH_INTERNAL_ASSERT(workspace_it != workspace.map.end());
     return workspace_it->second.mutable_get();
   }
 #endif
   cublasLtHandle_t handle = getCurrentCUDABlasLtHandle();
   auto stream = c10::cuda::getCurrentCUDAStream();
   cudaStream_t _stream = stream;
   auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
-  auto workspace_it = cublaslt_handle_stream_to_workspace().find(key);
-  if (workspace_it == cublaslt_handle_stream_to_workspace().end()) {
-    workspace_it = cublaslt_handle_stream_to_workspace().insert(workspace_it, {key, getNewCUDABlasLtWorkspace()});
+
+  auto& workspace = cublaslt_handle_stream_to_workspace();
+
+  // Fast path: check if workspace already exists
+  {
+    std::shared_lock<std::shared_mutex> lock(workspace.mutex);
+    auto workspace_it = workspace.map.find(key);
+    if (workspace_it != workspace.map.end()) {
+      return workspace_it->second.mutable_get();
+    }
+  }
+
+  // Slow path: allocate workspace outside the lock
+  auto new_workspace = getNewCUDABlasLtWorkspace();
+
+  // Insert with lock (double-check in case another thread inserted while we
+  // were allocating)
+  {
+    std::unique_lock<std::shared_mutex> lock(workspace.mutex);
+    auto workspace_it = workspace.map.find(key);
+    if (workspace_it == workspace.map.end()) {
+      workspace_it =
+          workspace.map.emplace(key, std::move(new_workspace)).first;
+    }
+    // else: another thread inserted it, our new_workspace will be automatically
+    // freed
+    return workspace_it->second.mutable_get();
   }
-  return workspace_it->second.mutable_get();
 }
 
 cublasHandle_t getCurrentCUDABlasHandle() {
@@ -300,11 +333,39 @@ cublasHandle_t getCurrentCUDABlasHandle() {
   // all the memory and cublas's cudaMallocAsync will return OOM
   cudaStream_t _stream = stream;
   auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
-  auto workspace_it = cublas_handle_stream_to_workspace().find(key);
-  if (workspace_it == cublas_handle_stream_to_workspace().end()) {
-    workspace_it = cublas_handle_stream_to_workspace().insert(workspace_it, {key, getNewWorkspace()});
+
+  auto& workspace = cublas_handle_stream_to_workspace();
+
+  size_t workspace_size = getChosenWorkspaceSize();
+
+  // Fast path: check if workspace already exists
+  {
+    std::shared_lock<std::shared_mutex> lock(workspace.mutex);
+    auto workspace_it = workspace.map.find(key);
+    if (workspace_it != workspace.map.end()) {
+      TORCH_CUDABLAS_CHECK(cublasSetWorkspace(
+          handle, workspace_it->second.get(), workspace_size));
+      return handle;
+    }
+  }
+
+  // Slow path: allocate workspace outside the lock
+  auto new_workspace = getNewWorkspace();
+
+  // Insert with lock (double-check in case another thread inserted while we
+  // were allocating)
+  {
+    std::unique_lock<std::shared_mutex> lock(workspace.mutex);
+    auto workspace_it = workspace.map.find(key);
+    if (workspace_it == workspace.map.end()) {
+      workspace_it =
+          workspace.map.emplace(key, std::move(new_workspace)).first;
+    }
+    // else: another thread inserted it, our new_workspace will be automatically
+    // freed
+    TORCH_CUDABLAS_CHECK(
+        cublasSetWorkspace(handle, workspace_it->second.get(), workspace_size));
   }
-  TORCH_CUDABLAS_CHECK(cublasSetWorkspace(handle, workspace_it->second.get(), getChosenWorkspaceSize()));
 #if !defined(USE_ROCM)
   // On CUDA >= 11, and architecture >= Ampere, cuBLAS can use TF32 to speedup
   // FP32 data type calculations based on the value of the allow_tf32 flag.
```
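Both rewritten call sites follow the same fast-path/slow-path shape. Below is a generic sketch of that pattern (all names illustrative, with `std::unique_ptr` standing in for `at::DataPtr`). The point of allocating outside the lock is that the workspace allocation may be expensive, so lookups in other threads are never blocked behind it; the second `find` under the writer lock settles the race when two threads allocate for the same key, and the loser's allocation is simply discarded.

```cpp
#include <map>
#include <memory>
#include <mutex>
#include <shared_mutex>

// Generic get-or-create over a shared_mutex-guarded map, mirroring the
// pattern used in getCurrentCUDABlasHandle / getCUDABlasLtWorkspace.
template <typename K, typename V>
V& get_or_create(std::map<K, std::unique_ptr<V>>& m,
                 std::shared_mutex& mu,
                 const K& key) {
  // Fast path: shared lock, concurrent with other readers.
  {
    std::shared_lock<std::shared_mutex> lock(mu);
    auto it = m.find(key);
    if (it != m.end()) {
      return *it->second;
    }
  }
  // Slow path: allocate outside the lock (stands in for getNewWorkspace()).
  auto fresh = std::make_unique<V>();
  // Double-check under the writer lock: another thread may have inserted
  // while we were allocating; if so, `fresh` is freed on scope exit.
  {
    std::unique_lock<std::shared_mutex> lock(mu);
    auto it = m.find(key);
    if (it == m.end()) {
      it = m.emplace(key, std::move(fresh)).first;
    }
    return *it->second;
  }
}
```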

aten/src/ATen/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

```diff
@@ -61,6 +61,7 @@ list(APPEND ATen_CUDA_TEST_SRCS
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda_complex_math_test.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda_complex_test.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda_cub_test.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/cuda_cublas_handle_pool_test.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda_device_test.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda_distributions_test.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda_dlconvertor_test.cpp
```
aten/src/ATen/test/cuda_cublas_handle_pool_test.cpp

Lines changed: 77 additions & 0 deletions

```diff
@@ -0,0 +1,77 @@
+#include <gtest/gtest.h>
+
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDACachingAllocator.h>
+#include <c10/cuda/CUDAGuard.h>
+
+#include <atomic>
+#include <thread>
+#include <vector>
+
+// Test concurrent access to getCurrentCUDABlasHandle and getCUDABlasLtWorkspace
+// to verify that the data race fix is working correctly
+
+TEST(CUDABlasHandlePoolTest, ConcurrentGetAndClearWorkspaces) {
+  if (!at::cuda::is_available()) {
+    return;
+  }
+
+  constexpr int num_accessor_threads = 15;
+  constexpr int num_clear_threads = 5;
+  constexpr int iterations_per_thread = 50;
+
+  std::atomic<bool> stop{false};
+  std::atomic<int> error_count{0};
+  std::vector<std::thread> threads;
+  threads.reserve(num_accessor_threads + num_clear_threads);
+
+  // Launch accessor threads
+  for (int i = 0; i < num_accessor_threads; ++i) {
+    threads.emplace_back([&stop, &error_count]() {
+      try {
+        at::cuda::CUDAGuard device_guard(0);
+
+        while (!stop.load(std::memory_order_relaxed)) {
+          const auto handle = at::cuda::getCurrentCUDABlasHandle();
+          const auto workspace = at::cuda::getCUDABlasLtWorkspace();
+
+          if (handle == nullptr || workspace == nullptr) {
+            error_count++;
+          }
+        }
+      } catch (const std::exception& e) {
+        error_count++;
+      }
+    });
+  }
+
+  // Launch threads that clear workspaces
+  for (int i = 0; i < num_clear_threads; ++i) {
+    threads.emplace_back([&error_count]() {
+      try {
+        for (int j = 0; j < iterations_per_thread; ++j) {
+          at::cuda::clearCublasWorkspaces();
+          std::this_thread::yield();
+        }
+      } catch (const std::exception& e) {
+        error_count++;
+      }
+    });
+  }
+
+  // Let them run for a bit
+  std::this_thread::sleep_for(std::chrono::milliseconds(100));
+  stop.store(true, std::memory_order_relaxed);
+
+  for (auto& thread : threads) {
+    thread.join();
+  }
+
+  EXPECT_EQ(error_count.load(), 0);
+}
+
+int main(int argc, char* argv[]) {
+  ::testing::InitGoogleTest(&argc, argv);
+  c10::cuda::CUDACachingAllocator::init(1);
+  return RUN_ALL_TESTS();
+}
```
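As a rough illustration of what this test exercises, here is a CUDA-free distillation using the same thread counts (all names illustrative, not part of the commit): reader threads hammer a `shared_mutex`-guarded map while other threads clear it. With the locks in place it runs clean under `-fsanitize=thread`; deleting them reproduces the class of race this commit fixes.

```cpp
#include <atomic>
#include <chrono>
#include <map>
#include <mutex>
#include <shared_mutex>
#include <thread>
#include <vector>

// Shared state standing in for the workspace map and its new mutex.
static std::shared_mutex g_mutex;
static std::map<int, int> g_map;

int main() {
  std::atomic<bool> stop{false};
  std::vector<std::thread> threads;

  // Accessor threads: repeated lookups under a shared (reader) lock.
  for (int i = 0; i < 15; ++i) {
    threads.emplace_back([&stop] {
      while (!stop.load(std::memory_order_relaxed)) {
        std::shared_lock<std::shared_mutex> lock(g_mutex);
        (void)g_map.find(0);
      }
    });
  }
  // Clearing threads: mutation under a unique (writer) lock.
  for (int i = 0; i < 5; ++i) {
    threads.emplace_back([] {
      for (int j = 0; j < 50; ++j) {
        std::unique_lock<std::shared_mutex> lock(g_mutex);
        g_map.clear();
      }
    });
  }

  std::this_thread::sleep_for(std::chrono::milliseconds(100));
  stop.store(true, std::memory_order_relaxed);
  for (auto& t : threads) t.join();
  return 0;
}
```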
