
Commit 76e9227

Merge pull request #13199 from JiayiFeng/fix_CudnnHolder_bug
Fix cudnn holder bug
2 parents 17bf871 + 8331e83
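
Summary (inferred from the diff below; the merge carries no further description): the convolution kernels used to allocate and free their own cuDNN workspace on every call. This commit moves workspace ownership into a new CudnnHolder inside CUDADeviceContext: each kernel now wraps its cuDNN call in a lambda that receives the workspace pointer and hands it to RunCudnnFuncWithWorkspace, which runs the lambda against a single cached workspace, grown on demand under a mutex. A RAII RWLockGuard helper is also added to framework/rw_lock.h.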

5 files changed: 195 additions, 73 deletions

paddle/fluid/framework/rw_lock.h

Lines changed: 71 additions & 0 deletions
@@ -56,5 +56,76 @@ struct RWLock {
 };
 #endif
 
+class RWLockGuard {
+ public:
+  enum Status { kUnLock, kWRLock, kRDLock };
+
+  RWLockGuard(RWLock* rw_lock, Status init_status)
+      : lock_(rw_lock), status_(Status::kUnLock) {
+    switch (init_status) {
+      case Status::kRDLock: {
+        RDLock();
+        break;
+      }
+      case Status::kWRLock: {
+        WRLock();
+        break;
+      }
+      case Status::kUnLock: {
+        break;
+      }
+    }
+  }
+
+  void WRLock() {
+    switch (status_) {
+      case Status::kUnLock: {
+        lock_->WRLock();
+        status_ = Status::kWRLock;
+        break;
+      }
+      case Status::kWRLock: {
+        break;
+      }
+      case Status::kRDLock: {
+        PADDLE_THROW(
+            "Please unlock read lock first before invoking write lock.");
+        break;
+      }
+    }
+  }
+
+  void RDLock() {
+    switch (status_) {
+      case Status::kUnLock: {
+        lock_->RDLock();
+        status_ = Status::kRDLock;
+        break;
+      }
+      case Status::kRDLock: {
+        break;
+      }
+      case Status::kWRLock: {
+        PADDLE_THROW(
+            "Please unlock write lock first before invoking read lock.");
+        break;
+      }
+    }
+  }
+
+  void UnLock() {
+    if (status_ != Status::kUnLock) {
+      lock_->UNLock();
+      status_ = Status::kUnLock;
+    }
+  }
+
+  ~RWLockGuard() { UnLock(); }
+
+ private:
+  RWLock* lock_;
+  Status status_;
+};
+
 }  // namespace framework
 }  // namespace paddle
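
For orientation, RWLockGuard is a scope guard over the existing RWLock: it takes the requested lock in its constructor and releases it in its destructor. A minimal usage sketch follows; the lookup table and functions are illustrative only, not part of this commit:

// Illustrative sketch only -- not part of this commit.
#include "paddle/fluid/framework/rw_lock.h"

#include <string>
#include <unordered_map>

using paddle::framework::RWLock;
using paddle::framework::RWLockGuard;

RWLock table_lock;
std::unordered_map<std::string, int> table;

int Lookup(const std::string& key) {
  // Many readers may hold the lock concurrently.
  RWLockGuard guard(&table_lock, RWLockGuard::Status::kRDLock);
  auto it = table.find(key);
  return it == table.end() ? 0 : it->second;
}  // guard's destructor calls UnLock() here, even on early return

void Store(const std::string& key, int value) {
  // Writers get exclusive access. Re-requesting the mode already held is a
  // no-op, while mixing read and write on one guard triggers PADDLE_THROW.
  RWLockGuard guard(&table_lock, RWLockGuard::Status::kWRLock);
  table[key] = value;
}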

paddle/fluid/operators/conv_cudnn_op.cu.cc

Lines changed: 27 additions & 30 deletions
@@ -118,7 +118,6 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
         output_channels / groups * output_height * output_width * output_depth;
     int group_offset_filter = filter->numel() / groups;
     // ------------------- cudnn conv workspace ---------------------
-    void* cudnn_workspace = nullptr;
     size_t workspace_size_in_bytes;  // final workspace to allocate.
     size_t workspace_size_limit = kCONV_CUDNN_WORKSPACE_LIMIT_BYTES;
     if (user_workspace_size > 0) {
@@ -159,20 +158,18 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE_LE(workspace_size_in_bytes, workspace_size_limit,
                       "workspace_size to be allocated exceeds the limit");
 
-    // Allocate on GPU memory
-    platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
-    cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
     // ------------------- cudnn conv forward ---------------------
     ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
     for (int i = 0; i < groups; i++) {
-      CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
-          handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
-          cudnn_filter_desc, filter_data + i * group_offset_filter,
-          cudnn_conv_desc, algo, cudnn_workspace, workspace_size_in_bytes,
-          &beta, cudnn_output_desc, output_data + i * group_offset_out));
+      auto cudnn_func = [&](void* cudnn_workspace) {
+        CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
+            handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
+            cudnn_filter_desc, filter_data + i * group_offset_filter,
+            cudnn_conv_desc, algo, cudnn_workspace, workspace_size_in_bytes,
+            &beta, cudnn_output_desc, output_data + i * group_offset_out));
+      };
+      dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes);
     }
-    // Release the cudnn workspace
-    paddle::memory::Free(gpu, cudnn_workspace);
   }
 };

@@ -314,41 +311,41 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
           cudnn_filter_desc, filter_algo, &tmp_size));
       workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size);
     }
-    // ------------------- cudnn conv workspace ---------------------
-    // Already on GPU
-    void* cudnn_workspace = nullptr;
-    platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
-    cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
+
     // ------------------- cudnn conv backward data ---------------------
     ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
     if (input_grad) {
       T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
       // Because beta is zero, it is unnecessary to reset input_grad.
 
       for (int i = 0; i < groups; i++) {
-        CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
-            handle, &alpha, cudnn_filter_desc,
-            filter_data + i * group_offset_filter, cudnn_output_grad_desc,
-            output_grad_data + i * group_offset_out, cudnn_conv_desc, data_algo,
-            cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_input_desc,
-            input_grad_data + i * group_offset_in));
+        auto cudnn_func = [&](void* cudnn_workspace) {
+          CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
+              handle, &alpha, cudnn_filter_desc,
+              filter_data + i * group_offset_filter, cudnn_output_grad_desc,
+              output_grad_data + i * group_offset_out, cudnn_conv_desc,
+              data_algo, cudnn_workspace, workspace_size_in_bytes, &beta,
+              cudnn_input_desc, input_grad_data + i * group_offset_in));
+        };
+        dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes);
       }
     }
     // ------------------- cudnn conv backward filter ---------------------
     if (filter_grad) {
       T* filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
       // Because beta is zero, it is unnecessary to reset filter_grad.
       for (int i = 0; i < groups; i++) {
-        CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
-            handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
-            cudnn_output_grad_desc, output_grad_data + i * group_offset_out,
-            cudnn_conv_desc, filter_algo, cudnn_workspace,
-            workspace_size_in_bytes, &beta, cudnn_filter_desc,
-            filter_grad_data + i * group_offset_filter));
+        auto cudnn_func = [&](void* cudnn_workspace) {
+          CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
+              handle, &alpha, cudnn_input_desc,
+              input_data + i * group_offset_in, cudnn_output_grad_desc,
+              output_grad_data + i * group_offset_out, cudnn_conv_desc,
+              filter_algo, cudnn_workspace, workspace_size_in_bytes, &beta,
+              cudnn_filter_desc, filter_grad_data + i * group_offset_filter));
+        };
+        dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes);
       }
     }
-    // Release the cudnn workspace
-    paddle::memory::Free(gpu, cudnn_workspace);
   }
 };
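
The same mechanical transformation is applied at every call site above: the body of the cuDNN call is unchanged, but the raw cudnn_workspace pointer now arrives as a lambda parameter supplied by the device context. A stripped-down model of the control flow, for illustration only; the Fake* names are hypothetical stand-ins for the real cuDNN and Paddle types:

// Illustrative model only; names prefixed Fake* are hypothetical.
#include <cstddef>
#include <cstdlib>
#include <functional>

void FakeConvolutionForward(void* workspace, size_t len) { /* cuDNN stand-in */ }

struct FakeDeviceContext {
  // Plays the role of CUDADeviceContext::RunCudnnFuncWithWorkspace:
  // the context, not the kernel, owns and provides the workspace.
  void RunCudnnFuncWithWorkspace(const std::function<void(void*)>& fn,
                                 size_t len) {
    void* ws = std::malloc(len);  // the real holder caches and reuses this
    fn(ws);
    std::free(ws);
  }
};

int main() {
  FakeDeviceContext dev_ctx;
  size_t workspace_size_in_bytes = 1 << 20;
  // The kernel wraps its call in a lambda, exactly as in the diff above.
  auto cudnn_func = [&](void* cudnn_workspace) {
    FakeConvolutionForward(cudnn_workspace, workspace_size_in_bytes);
  };
  dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes);
  return 0;
}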

paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc

Lines changed: 26 additions & 33 deletions
@@ -76,7 +76,6 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
     conv_desc.descriptor<T>(paddings, strides, dilations);
 
     // ------------------- cudnn conv workspace ---------------------
-    void* cudnn_workspace = nullptr;
     size_t workspace_size_in_bytes;  // final workspace to allocate.
     size_t workspace_size_limit = kConvCUDNNWorkspaceLimitBytes;
     if (user_workspace_size > 0) {
@@ -100,25 +99,21 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
             handle, cudnn_filter_desc, cudnn_input_desc, cudnn_conv_desc,
             cudnn_output_desc, algo, &workspace_size_in_bytes));
 
-    // Allocate on GPU memory
-    platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
-    cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
-
     // ------------------- cudnn conv transpose forward ---------------------
     int input_offset = input->numel() / input->dims()[0] / groups;
     int output_offset = output->numel() / output->dims()[0] / groups;
     int filter_offset = filter->numel() / groups;
     T alpha = 1.0f, beta = 0.0f;
     for (int g = 0; g < groups; g++) {
-      CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
-          handle, &alpha, cudnn_filter_desc, filter_data + filter_offset * g,
-          cudnn_input_desc, input_data + input_offset * g, cudnn_conv_desc,
-          algo, cudnn_workspace, workspace_size_in_bytes, &beta,
-          cudnn_output_desc, output_data + output_offset * g));
+      auto cudnn_func = [&](void* cudnn_workspace) {
+        CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
+            handle, &alpha, cudnn_filter_desc, filter_data + filter_offset * g,
+            cudnn_input_desc, input_data + input_offset * g, cudnn_conv_desc,
+            algo, cudnn_workspace, workspace_size_in_bytes, &beta,
+            cudnn_output_desc, output_data + output_offset * g));
+      };
+      dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes);
     }
-
-    // Release the cudnn workspace
-    paddle::memory::Free(gpu, cudnn_workspace);
   }
 };

@@ -206,11 +201,6 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
         std::max(workspace_size_in_bytes, bwd_filter_ws_size);
     }
 
-    // ------------------- cudnn conv workspace ---------------------
-    // Already on GPU
-    void* cudnn_workspace = nullptr;
-    platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
-    cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
     // ------------------- cudnn conv backward data ---------------------
     // FIXME(typhoonzero): template type T may not be the same as cudnn call.
     int input_offset = input->numel() / input->dims()[0] / groups;
@@ -222,12 +212,15 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
       T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
       // Because beta is zero, it is unnecessary to reset input_grad.
       for (int g = 0; g < groups; g++) {
-        CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
-            handle, &alpha, cudnn_output_desc,
-            output_grad_data + output_grad_offset * g, cudnn_filter_desc,
-            filter_data + filter_offset * g, cudnn_conv_desc, data_algo,
-            cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_input_desc,
-            input_grad_data + input_offset * g));
+        auto cudnn_func = [&](void* cudnn_workspace) {
+          CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
+              handle, &alpha, cudnn_output_desc,
+              output_grad_data + output_grad_offset * g, cudnn_filter_desc,
+              filter_data + filter_offset * g, cudnn_conv_desc, data_algo,
+              cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_input_desc,
+              input_grad_data + input_offset * g));
+        };
+        dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes);
       }
     }

@@ -237,17 +230,17 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
       // Because beta is zero, it is unnecessary to reset filter_grad.
       // Gradient with respect to the filter
       for (int g = 0; g < groups; g++) {
-        CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
-            handle, &alpha, cudnn_output_desc,
-            output_grad_data + output_grad_offset * g, cudnn_input_desc,
-            input_data + input_offset * g, cudnn_conv_desc, filter_algo,
-            cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_filter_desc,
-            filter_grad_data + filter_offset * g));
+        auto cudnn_func = [&](void* cudnn_workspace) {
+          CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
+              handle, &alpha, cudnn_output_desc,
+              output_grad_data + output_grad_offset * g, cudnn_input_desc,
+              input_data + input_offset * g, cudnn_conv_desc, filter_algo,
+              cudnn_workspace, workspace_size_in_bytes, &beta,
+              cudnn_filter_desc, filter_grad_data + filter_offset * g));
+        };
+        dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes);
       }
     }
-
-    // Release the cudnn workspace
-    paddle::memory::Free(gpu, cudnn_workspace);
   }
 };
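
Note that with the device context owning the buffer, the manual Alloc/Free pair at the top and bottom of each kernel disappears. A side benefit, not stated in the commit but visible in the old code: if a CUDNN_ENFORCE inside the loop threw, the trailing paddle::memory::Free was skipped and the workspace leaked; routing the call through the holder removes that failure mode.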

paddle/fluid/platform/device_context.cc

Lines changed: 64 additions & 9 deletions
@@ -16,6 +16,9 @@ limitations under the License. */
 #include <vector>
 
 #include "paddle/fluid/memory/memory.h"
+#ifdef PADDLE_WITH_CUDA
+#include "paddle/fluid/framework/rw_lock.h"
+#endif
 
 namespace paddle {
 namespace platform {
@@ -142,7 +145,58 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
   mutable unsigned int* semaphore_;
 };
 
-CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
+class CudnnHolder {
+ public:
+  CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place)
+      : workspace_(nullptr), workspace_len_(0), stream_(stream), place_(place) {
+    PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
+    PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, *stream_));
+  }
+
+  cudnnHandle_t cudnn_handle() const { return cudnn_handle_; }
+
+  void RunFunc(const std::function<void(void*)>& cudnn_func,
+               size_t required_workspace_len) {
+    std::lock_guard<std::mutex> lock(mtx_);
+    if (required_workspace_len > workspace_len_) {
+      ReallocateWorkspace(required_workspace_len);
+    }
+    cudnn_func(workspace_);
+  }
+
+  ~CudnnHolder() {
+    PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
+    if (workspace_ != nullptr) {
+      paddle::memory::Free(place_, workspace_);
+    }
+  }
+
+ private:
+  void ReallocateWorkspace(size_t required_workspace_len) {
+    if (required_workspace_len <= workspace_len_) {
+      return;
+    }
+    if (workspace_ != nullptr) {
+      // Maybe someone is using the current workspace
+      PADDLE_ENFORCE(cudaStreamSynchronize(*stream_));
+      paddle::memory::Free(place_, workspace_);
+    }
+    workspace_ = paddle::memory::Alloc(place_, required_workspace_len);
+    workspace_len_ = required_workspace_len;
+  }
+
+  cudnnHandle_t cudnn_handle_;
+  void* workspace_;
+  size_t workspace_len_;
+
+  const cudaStream_t* stream_;  // not owned;
+  const CUDAPlace place_;
+
+  std::mutex mtx_;
+};
+
+CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
+    : place_(place), cudnn_holder_(nullptr) {
   SetDeviceId(place_.device);
   compute_capability = GetCUDAComputeCapability(place_.device);
   multi_process = GetCUDAMultiProcessors(place_.device);
@@ -154,20 +208,14 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
   PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_));
   PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_));
   if (dynload::HasCUDNN()) {
-    PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
-    PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream_));
-  } else {
-    cudnn_handle_ = nullptr;
+    cudnn_holder_.reset(new CudnnHolder(&stream_, place));
   }
 }
 
 CUDADeviceContext::~CUDADeviceContext() {
   SetDeviceId(place_.device);
   Wait();
   PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_));
-  if (cudnn_handle_ != nullptr) {
-    PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
-  }
   eigen_stream_.reset();
   eigen_device_.reset();
   PADDLE_ENFORCE(cudaStreamDestroy(stream_));
@@ -196,7 +244,14 @@ cublasHandle_t CUDADeviceContext::cublas_handle() const {
   return cublas_handle_;
 }
 
-cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; }
+cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
+  return cudnn_holder_->cudnn_handle();
+}
+
+void CUDADeviceContext::RunCudnnFuncWithWorkspace(
+    const std::function<void(void*)>& cudnn_func, size_t workspace_len) const {
+  cudnn_holder_->RunFunc(cudnn_func, workspace_len);
+}
 
 cudaStream_t CUDADeviceContext::stream() const { return stream_; }
