@@ -99,27 +99,35 @@ void destroyCublasHandle(cublasHandle_t handle) {
9999// - Comments of @soumith copied from cuDNN handle pool implementation
100100#ifdef NO_CUDNN_DESTROY_HANDLE
101101#else
102- cublasDestroy (handle);
102+ cublasDestroy (handle);
103103#endif
104104}
105105
// Alias for DeviceThreadHandlePool instantiated with cublasHandle_t as the
// handle type and createCublasHandle / destroyCublasHandle as its remaining
// template arguments.
106106using CuBlasPoolType = DeviceThreadHandlePool<cublasHandle_t, createCublasHandle, destroyCublasHandle>;
107107
108108} // namespace
109109
// Accessor for the shared cuBLAS workspace table.
// The table is allocated with `new` on the first call and bound to a
// function-local static reference; nothing in this function deletes it, and
// every call returns that same instance.
// This diff changes the return type from a bare
// std::map<std::tuple<void*, void*>, at::DataPtr> to WorkspaceMapWithMutex,
// whose `.map` and `.mutex` members are what the callers in this diff use.
110- std::map<std::tuple< void *, void *>, at::DataPtr> & cublas_handle_stream_to_workspace () {
111- static auto & instance = *new std::map<std::tuple< void *, void *>, at::DataPtr> ;
110+ WorkspaceMapWithMutex & cublas_handle_stream_to_workspace () {
111+ static auto & instance = *new WorkspaceMapWithMutex ;
112112 return instance;
113113}
114114
// Accessor for the shared cuBLASLt workspace table.
// The table is allocated with `new` on the first call and bound to a
// function-local static reference; nothing in this function deletes it, and
// every call returns that same instance.
// This diff changes the return type from a bare
// std::map<std::tuple<void*, void*>, at::DataPtr> to WorkspaceMapWithMutex,
// whose `.map` and `.mutex` members are what the callers in this diff use.
115- std::map<std::tuple< void *, void *>, at::DataPtr> & cublaslt_handle_stream_to_workspace () {
116- static auto & instance = *new std::map<std::tuple< void *, void *>, at::DataPtr> ;
115+ WorkspaceMapWithMutex & cublaslt_handle_stream_to_workspace () {
116+ static auto & instance = *new WorkspaceMapWithMutex ;
117117 return instance;
118118}
119119
// Empties both workspace tables: the cuBLAS table first, then the cuBLASLt one.
// Before this diff the two maps were cleared directly with no locking.
// After it, each table's `.map` is cleared while holding a std::unique_lock on
// that table's std::shared_mutex. The two locks live in separate brace scopes,
// so the first is released before the second is acquired and they are never
// held at the same time.
120120void clearCublasWorkspaces () {
121- cublas_handle_stream_to_workspace ().clear ();
122- cublaslt_handle_stream_to_workspace ().clear ();
121+ {
122+ auto & workspace = cublas_handle_stream_to_workspace ();
123+ std::unique_lock<std::shared_mutex> lock (workspace.mutex );
124+ workspace.map .clear ();
125+ }
126+ {
127+ auto & workspace = cublaslt_handle_stream_to_workspace ();
128+ std::unique_lock<std::shared_mutex> lock (workspace.mutex );
129+ workspace.map .clear ();
130+ }
123131}
124132
125133size_t parseChosenWorkspaceSize () {
@@ -241,20 +249,45 @@ void* getCUDABlasLtWorkspace() {
241249 auto stream = c10::cuda::getCurrentCUDAStream ();
242250 cudaStream_t _stream = stream;
243251 auto key = std::make_tuple (static_cast <void *>(handle), static_cast <void *>(_stream));
244- auto workspace_it = at::cuda::cublas_handle_stream_to_workspace ().find (key);
245- TORCH_INTERNAL_ASSERT (workspace_it != at::cuda::cublas_handle_stream_to_workspace ().end ());
252+ auto & workspace = at::cuda::cublas_handle_stream_to_workspace ();
253+ std::shared_lock<std::shared_mutex> lock (workspace.mutex );
254+ auto workspace_it = workspace.map .find (key);
255+ TORCH_INTERNAL_ASSERT (workspace_it != workspace.map .end ());
246256 return workspace_it->second .mutable_get ();
247257 }
248258#endif
249259 cublasLtHandle_t handle = getCurrentCUDABlasLtHandle ();
250260 auto stream = c10::cuda::getCurrentCUDAStream ();
251261 cudaStream_t _stream = stream;
252262 auto key = std::make_tuple (static_cast <void *>(handle), static_cast <void *>(_stream));
253- auto workspace_it = cublaslt_handle_stream_to_workspace ().find (key);
254- if (workspace_it == cublaslt_handle_stream_to_workspace ().end ()) {
255- workspace_it = cublaslt_handle_stream_to_workspace ().insert (workspace_it, {key, getNewCUDABlasLtWorkspace ()});
263+
264+ auto & workspace = cublaslt_handle_stream_to_workspace ();
265+
266+ // Fast path: check if workspace already exists
267+ {
268+ std::shared_lock<std::shared_mutex> lock (workspace.mutex );
269+ auto workspace_it = workspace.map .find (key);
270+ if (workspace_it != workspace.map .end ()) {
271+ return workspace_it->second .mutable_get ();
272+ }
273+ }
274+
275+ // Slow path: allocate workspace outside the lock
276+ auto new_workspace = getNewCUDABlasLtWorkspace ();
277+
278+ // Insert with lock (double-check in case another thread inserted while we
279+ // were allocating)
280+ {
281+ std::unique_lock<std::shared_mutex> lock (workspace.mutex );
282+ auto workspace_it = workspace.map .find (key);
283+ if (workspace_it == workspace.map .end ()) {
284+ workspace_it =
285+ workspace.map .emplace (key, std::move (new_workspace)).first ;
286+ }
287+ // else: another thread inserted it, our new_workspace will be automatically
288+ // freed
289+ return workspace_it->second .mutable_get ();
256290 }
257- return workspace_it->second .mutable_get ();
258291}
259292
260293cublasHandle_t getCurrentCUDABlasHandle () {
@@ -300,11 +333,39 @@ cublasHandle_t getCurrentCUDABlasHandle() {
300333 // all the memory and cublas's cudaMallocAsync will return OOM
301334 cudaStream_t _stream = stream;
302335 auto key = std::make_tuple (static_cast <void *>(handle), static_cast <void *>(_stream));
303- auto workspace_it = cublas_handle_stream_to_workspace ().find (key);
304- if (workspace_it == cublas_handle_stream_to_workspace ().end ()) {
305- workspace_it = cublas_handle_stream_to_workspace ().insert (workspace_it, {key, getNewWorkspace ()});
336+
337+ auto & workspace = cublas_handle_stream_to_workspace ();
338+
339+ size_t workspace_size = getChosenWorkspaceSize ();
340+
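341+ // Fast path: check if workspace already exists
// NOTE(review): the `return handle;` inside this fast path leaves the function
// before the `#if !defined(USE_ROCM)` TF32 section that follows the slow-path
// block. Before this change every call fell through to that section after
// cublasSetWorkspace. The rest of the function is not visible here, so confirm
// that skipping it on a workspace hit is intended.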
342+ {
343+ std::shared_lock<std::shared_mutex> lock (workspace.mutex );
344+ auto workspace_it = workspace.map .find (key);
345+ if (workspace_it != workspace.map .end ()) {
346+ TORCH_CUDABLAS_CHECK (cublasSetWorkspace (
347+ handle, workspace_it->second .get (), workspace_size));
348+ return handle;
349+ }
350+ }
351+
352+ // Slow path: allocate workspace outside the lock
353+ auto new_workspace = getNewWorkspace ();
354+
355+ // Insert with lock (double-check in case another thread inserted while we
356+ // were allocating)
357+ {
358+ std::unique_lock<std::shared_mutex> lock (workspace.mutex );
359+ auto workspace_it = workspace.map .find (key);
360+ if (workspace_it == workspace.map .end ()) {
361+ workspace_it =
362+ workspace.map .emplace (key, std::move (new_workspace)).first ;
363+ }
364+ // else: another thread inserted it, our new_workspace will be automatically
365+ // freed
366+ TORCH_CUDABLAS_CHECK (
367+ cublasSetWorkspace (handle, workspace_it->second .get (), workspace_size));
306368 }
307- TORCH_CUDABLAS_CHECK (cublasSetWorkspace (handle, workspace_it->second .get (), getChosenWorkspaceSize ()));
308369#if !defined(USE_ROCM)
309370 // On CUDA >= 11, and architecture >= Ampere, cuBLAS can use TF32 to speedup
310371 // FP32 data type calculations based on the value of the allow_tf32 flag.
0 commit comments