Skip to content

Commit c7b32fe

Browse files
authored
Add cublas_handle() to expose cublas_handle to ops (#31157) (#31190)
* add get_cublas_handle() api * update format * add unittests * alter function name
1 parent 0def593 commit c7b32fe

File tree

4 files changed

+12
-1
lines changed

4 files changed

+12
-1
lines changed

paddle/fluid/platform/cuda_helper.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,14 @@ class CublasHandleHolder {
9696
#endif // CUDA_VERSION >= 9000
9797
}
9898

99+
const cublasHandle_t& GetCublasHandle() const { return handle_; }
100+
99101
~CublasHandleHolder() PADDLE_MAY_THROW {
100102
PADDLE_RETRY_CUDA_SUCCESS(dynload::cublasDestroy(handle_));
101103
}
102104

103105
template <typename Callback>
104-
inline void Call(Callback &&callback) const {
106+
inline void Call(Callback&& callback) const {
105107
std::lock_guard<std::mutex> guard(mtx_);
106108
callback(handle_);
107109
}

paddle/fluid/platform/device_context.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,10 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
437437
return context()->CudnnHandle();
438438
}
439439

440+
cublasHandle_t CUDADeviceContext::cublas_handle() const {
441+
return context()->CublasHandle()->GetCublasHandle();
442+
}
443+
440444
CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
441445
return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_);
442446
}

paddle/fluid/platform/device_context.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,9 @@ class CUDADeviceContext : public DeviceContext {
346346
/*! \brief Return cudnn handle in the device context. */
347347
cudnnHandle_t cudnn_handle() const;
348348

349+
/*! \brief Return cublas handle in the device context. */
350+
cublasHandle_t cublas_handle() const;
351+
349352
/*! \brief Return a cudnn workspace handle to call multiple cudnn
350353
* functions without interrupting by other threads.
351354
* Once the first cudnn function is called by the handle, a lock

paddle/fluid/platform/device_context_test.cu

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ TEST(Device, CUDADeviceContext) {
4343
ASSERT_NE(nullptr, gpu_device);
4444
cudnnHandle_t cudnn_handle = device_context->cudnn_handle();
4545
ASSERT_NE(nullptr, cudnn_handle);
46+
cublasHandle_t cublas_handle = device_context->cublas_handle();
47+
ASSERT_NE(nullptr, cublas_handle);
4648
delete device_context;
4749
}
4850
}

0 commit comments

Comments
 (0)