Commit ec5204b

Merge pull request #13195 from PaddlePaddle/revert-13078-dev_CudnnHolder
Revert "Add CudnnHolder and use it in Conv and ConvTranspose op"
2 parents 7117641 + 151e169

5 files changed: +73 -196 lines changed

paddle/fluid/framework/rw_lock.h

Lines changed: 0 additions & 71 deletions
@@ -56,76 +56,5 @@ struct RWLock {
 };
 #endif
 
-class RWLockGuard {
- public:
-  enum Status { kUnLock, kWRLock, kRDLock };
-
-  RWLockGuard(RWLock* rw_lock, Status init_status)
-      : lock_(rw_lock), status_(Status::kUnLock) {
-    switch (init_status) {
-      case Status::kRDLock: {
-        RDLock();
-        break;
-      }
-      case Status::kWRLock: {
-        WRLock();
-        break;
-      }
-      case Status::kUnLock: {
-        break;
-      }
-    }
-  }
-
-  void WRLock() {
-    switch (status_) {
-      case Status::kUnLock: {
-        lock_->WRLock();
-        status_ = Status::kWRLock;
-        break;
-      }
-      case Status::kWRLock: {
-        break;
-      }
-      case Status::kRDLock: {
-        PADDLE_THROW(
-            "Please unlock read lock first before invoking write lock.");
-        break;
-      }
-    }
-  }
-
-  void RDLock() {
-    switch (status_) {
-      case Status::kUnLock: {
-        lock_->RDLock();
-        status_ = Status::kRDLock;
-        break;
-      }
-      case Status::kRDLock: {
-        break;
-      }
-      case Status::kWRLock: {
-        PADDLE_THROW(
-            "Please unlock write lock first before invoking read lock.");
-        break;
-      }
-    }
-  }
-
-  void UnLock() {
-    if (status_ != Status::kUnLock) {
-      lock_->UNLock();
-      status_ = Status::kUnLock;
-    }
-  }
-
-  ~RWLockGuard() { UnLock(); }
-
- private:
-  RWLock* lock_;
-  Status status_;
-};
-
 }  // namespace framework
 }  // namespace paddle
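
The deleted RWLockGuard was a scoped guard over Paddle's pthread-based RWLock: it took a read or write lock on construction and guaranteed release on destruction. A minimal sketch of the same RAII pattern, written against std::shared_mutex so it is self-contained (the name ScopedRWLock and the use of std::shared_mutex are assumptions for illustration, not Paddle code):

#include <shared_mutex>

// Stand-in for the removed RWLockGuard: acquire a read or write lock on
// construction, release it on destruction.
class ScopedRWLock {
 public:
  enum class Mode { kRead, kWrite };

  ScopedRWLock(std::shared_mutex* mtx, Mode mode) : mtx_(mtx), mode_(mode) {
    if (mode_ == Mode::kRead) {
      mtx_->lock_shared();  // shared (read) lock
    } else {
      mtx_->lock();  // exclusive (write) lock
    }
  }

  ~ScopedRWLock() {
    if (mode_ == Mode::kRead) {
      mtx_->unlock_shared();
    } else {
      mtx_->unlock();
    }
  }

  ScopedRWLock(const ScopedRWLock&) = delete;
  ScopedRWLock& operator=(const ScopedRWLock&) = delete;

 private:
  std::shared_mutex* mtx_;  // not owned
  Mode mode_;
};

Releasing on every exit path is what the guard added on top of the raw lock; it goes away together with the CudnnHolder change that introduced it, and device_context.cc drops its rw_lock.h include in the last file below.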

paddle/fluid/operators/conv_cudnn_op.cu.cc

Lines changed: 30 additions & 27 deletions
@@ -118,6 +118,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
         output_channels / groups * output_height * output_width * output_depth;
     int group_offset_filter = filter->numel() / groups;
     // ------------------- cudnn conv workspace ---------------------
+    void* cudnn_workspace = nullptr;
     size_t workspace_size_in_bytes;  // final workspace to allocate.
     size_t workspace_size_limit = kCONV_CUDNN_WORKSPACE_LIMIT_BYTES;
     if (user_workspace_size > 0) {
@@ -158,18 +159,20 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE_LE(workspace_size_in_bytes, workspace_size_limit,
                       "workspace_size to be allocated exceeds the limit");
 
+    // Allocate on GPU memory
+    platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
+    cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
     // ------------------- cudnn conv forward ---------------------
     ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
     for (int i = 0; i < groups; i++) {
-      auto cudnn_func = [&](void* cudnn_workspace) {
-        CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
-            handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
-            cudnn_filter_desc, filter_data + i * group_offset_filter,
-            cudnn_conv_desc, algo, cudnn_workspace, workspace_size_in_bytes,
-            &beta, cudnn_output_desc, output_data + i * group_offset_out));
-      };
-      dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes);
+      CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
+          handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
+          cudnn_filter_desc, filter_data + i * group_offset_filter,
+          cudnn_conv_desc, algo, cudnn_workspace, workspace_size_in_bytes,
+          &beta, cudnn_output_desc, output_data + i * group_offset_out));
     }
+    // Release the cudnn workspace
+    paddle::memory::Free(gpu, cudnn_workspace);
   }
 };
@@ -311,41 +314,41 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
           cudnn_filter_desc, filter_algo, &tmp_size));
       workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size);
     }
-
+    // ------------------- cudnn conv workspace ---------------------
+    // Already on GPU
+    void* cudnn_workspace = nullptr;
+    platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
+    cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
     // ------------------- cudnn conv backward data ---------------------
     ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
     if (input_grad) {
       T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
       // Because beta is zero, it is unnecessary to reset input_grad.
 
       for (int i = 0; i < groups; i++) {
-        auto cudnn_func = [&](void* cudnn_workspace) {
-          CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
-              handle, &alpha, cudnn_filter_desc,
-              filter_data + i * group_offset_filter, cudnn_output_grad_desc,
-              output_grad_data + i * group_offset_out, cudnn_conv_desc,
-              data_algo, cudnn_workspace, workspace_size_in_bytes, &beta,
-              cudnn_input_desc, input_grad_data + i * group_offset_in));
-        };
-        dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes);
+        CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
+            handle, &alpha, cudnn_filter_desc,
+            filter_data + i * group_offset_filter, cudnn_output_grad_desc,
+            output_grad_data + i * group_offset_out, cudnn_conv_desc, data_algo,
+            cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_input_desc,
+            input_grad_data + i * group_offset_in));
       }
     }
     // ------------------- cudnn conv backward filter ---------------------
     if (filter_grad) {
       T* filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
       // Because beta is zero, it is unnecessary to reset filter_grad.
       for (int i = 0; i < groups; i++) {
-        auto cudnn_func = [&](void* cudnn_workspace) {
-          CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
-              handle, &alpha, cudnn_input_desc,
-              input_data + i * group_offset_in, cudnn_output_grad_desc,
-              output_grad_data + i * group_offset_out, cudnn_conv_desc,
-              filter_algo, cudnn_workspace, workspace_size_in_bytes, &beta,
-              cudnn_filter_desc, filter_grad_data + i * group_offset_filter));
-        };
-        dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes);
+        CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
+            handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
+            cudnn_output_grad_desc, output_grad_data + i * group_offset_out,
+            cudnn_conv_desc, filter_algo, cudnn_workspace,
+            workspace_size_in_bytes, &beta, cudnn_filter_desc,
+            filter_grad_data + i * group_offset_filter));
      }
    }
+    // Release the cudnn workspace
+    paddle::memory::Free(gpu, cudnn_workspace);
  }
};
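
On the kernel side the revert is mechanical: instead of packaging each cuDNN call in a lambda and handing it to dev_ctx.RunCudnnFuncWithWorkspace, every kernel once again allocates its own workspace before the group loop and frees it afterwards. A self-contained sketch of that allocate/run/free shape, with cudaMalloc/cudaFree standing in for paddle::memory::Alloc/Free (an assumption made so the example compiles on its own):

#include <cuda_runtime.h>

#include <cstddef>

// Allocate a workspace once, run per-group work against it, free it at
// the end -- the lifetime the conv kernels revert to in this commit.
template <typename Fn>
cudaError_t RunWithWorkspace(size_t workspace_bytes, int groups, Fn fn) {
  void* workspace = nullptr;
  cudaError_t err = cudaMalloc(&workspace, workspace_bytes);
  if (err != cudaSuccess) return err;
  for (int g = 0; g < groups; ++g) {
    // fn plays the role of one cudnnConvolution* call per group.
    fn(g, workspace, workspace_bytes);
  }
  return cudaFree(workspace);
}

The cost of this shape is one device allocation and free per kernel invocation; avoiding that repeated pair is what the reverted CudnnHolder's cached workspace bought.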

paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc

Lines changed: 33 additions & 26 deletions
@@ -76,6 +76,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
     conv_desc.descriptor<T>(paddings, strides, dilations);
 
     // ------------------- cudnn conv workspace ---------------------
+    void* cudnn_workspace = nullptr;
     size_t workspace_size_in_bytes;  // final workspace to allocate.
     size_t workspace_size_limit = kConvCUDNNWorkspaceLimitBytes;
     if (user_workspace_size > 0) {
@@ -99,21 +100,25 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
         handle, cudnn_filter_desc, cudnn_input_desc, cudnn_conv_desc,
         cudnn_output_desc, algo, &workspace_size_in_bytes));
 
+    // Allocate on GPU memory
+    platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
+    cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
+
     // ------------------- cudnn conv transpose forward ---------------------
     int input_offset = input->numel() / input->dims()[0] / groups;
     int output_offset = output->numel() / output->dims()[0] / groups;
     int filter_offset = filter->numel() / groups;
     T alpha = 1.0f, beta = 0.0f;
     for (int g = 0; g < groups; g++) {
-      auto cudnn_func = [&](void* cudnn_workspace) {
-        CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
-            handle, &alpha, cudnn_filter_desc, filter_data + filter_offset * g,
-            cudnn_input_desc, input_data + input_offset * g, cudnn_conv_desc,
-            algo, cudnn_workspace, workspace_size_in_bytes, &beta,
-            cudnn_output_desc, output_data + output_offset * g));
-      };
-      dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes);
+      CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
+          handle, &alpha, cudnn_filter_desc, filter_data + filter_offset * g,
+          cudnn_input_desc, input_data + input_offset * g, cudnn_conv_desc,
+          algo, cudnn_workspace, workspace_size_in_bytes, &beta,
+          cudnn_output_desc, output_data + output_offset * g));
     }
+
+    // Release the cudnn workspace
+    paddle::memory::Free(gpu, cudnn_workspace);
   }
 };
@@ -201,6 +206,11 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
           std::max(workspace_size_in_bytes, bwd_filter_ws_size);
     }
 
+    // ------------------- cudnn conv workspace ---------------------
+    // Already on GPU
+    void* cudnn_workspace = nullptr;
+    platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
+    cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
     // ------------------- cudnn conv backward data ---------------------
     // FIXME(typhoonzero): template type T may not be the same as cudnn call.
     int input_offset = input->numel() / input->dims()[0] / groups;
@@ -212,15 +222,12 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
       T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
       // Because beta is zero, it is unnecessary to reset input_grad.
       for (int g = 0; g < groups; g++) {
-        auto cudnn_func = [&](void* cudnn_workspace) {
-          CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
-              handle, &alpha, cudnn_output_desc,
-              output_grad_data + output_grad_offset * g, cudnn_filter_desc,
-              filter_data + filter_offset * g, cudnn_conv_desc, data_algo,
-              cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_input_desc,
-              input_grad_data + input_offset * g));
-        };
-        dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes);
+        CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
+            handle, &alpha, cudnn_output_desc,
+            output_grad_data + output_grad_offset * g, cudnn_filter_desc,
+            filter_data + filter_offset * g, cudnn_conv_desc, data_algo,
+            cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_input_desc,
+            input_grad_data + input_offset * g));
       }
     }
@@ -230,17 +237,17 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
       // Because beta is zero, it is unnecessary to reset filter_grad.
       // Gradient with respect to the filter
       for (int g = 0; g < groups; g++) {
-        auto cudnn_func = [&](void* cudnn_workspace) {
-          CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
-              handle, &alpha, cudnn_output_desc,
-              output_grad_data + output_grad_offset * g, cudnn_input_desc,
-              input_data + input_offset * g, cudnn_conv_desc, filter_algo,
-              cudnn_workspace, workspace_size_in_bytes, &beta,
-              cudnn_filter_desc, filter_grad_data + filter_offset * g));
-        };
-        dev_ctx.RunCudnnFuncWithWorkspace(cudnn_func, workspace_size_in_bytes);
+        CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
+            handle, &alpha, cudnn_output_desc,
+            output_grad_data + output_grad_offset * g, cudnn_input_desc,
+            input_data + input_offset * g, cudnn_conv_desc, filter_algo,
+            cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_filter_desc,
+            filter_grad_data + filter_offset * g));
       }
     }
+
+    // Release the cudnn workspace
+    paddle::memory::Free(gpu, cudnn_workspace);
   }
 };
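
Both transpose kernels walk the groups with the same pointer arithmetic as the conv kernels: each group's block is numel / batch / groups elements past the previous one. A worked illustration with made-up NCHW sizes (the shapes are hypothetical, chosen only to make the arithmetic concrete):

#include <cassert>

int main() {
  // Hypothetical NCHW input: N=2, C=8, H=W=16, convolved with groups=2.
  const int n = 2, c = 8, h = 16, w = 16, groups = 2;
  const int numel = n * c * h * w;              // 4096 elements in total
  const int input_offset = numel / n / groups;  // (8*16*16)/2 = 1024
  // input_data + input_offset * g then addresses group g's block of
  // c/groups = 4 channels within a sample; the cuDNN tensor descriptors
  // supply the batch stride.
  assert(input_offset == (c / groups) * h * w);
  return 0;
}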

paddle/fluid/platform/device_context.cc

Lines changed: 9 additions & 65 deletions
@@ -16,9 +16,6 @@ limitations under the License. */
 #include <vector>
 
 #include "paddle/fluid/memory/memory.h"
-#ifdef PADDLE_WITH_CUDA
-#include "paddle/fluid/framework/rw_lock.h"
-#endif
 
 namespace paddle {
 namespace platform {
@@ -145,59 +142,7 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
   mutable unsigned int* semaphore_;
 };
 
-class CudnnHolder {
- public:
-  CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place)
-      : workspace_(nullptr), workspace_len_(0), stream_(stream), place_(place) {
-    PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
-    PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, *stream_));
-  }
-
-  cudnnHandle_t cudnn_handle() const { return cudnn_handle_; }
-
-  void RunFunc(const std::function<void(void*)>& cudnn_func,
-               size_t required_workspace_len) {
-    std::lock_guard<std::mutex> lock(mtx_);
-    if (required_workspace_len > workspace_len_) {
-      ReallocateWorkspace(required_workspace_len);
-    }
-    cudnn_func(workspace_);
-  }
-
-  ~CudnnHolder() {
-    PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
-    if (workspace_ != nullptr) {
-      paddle::memory::Free(place_, workspace_);
-    }
-  }
-
- private:
-  void ReallocateWorkspace(size_t required_workspace_len) {
-    if (required_workspace_len <= workspace_len_) {
-      return;
-    }
-    void* new_workspace = paddle::memory::Alloc(place_, required_workspace_len);
-    if (workspace_ != nullptr) {
-      // Maybe someone is using the current workspace
-      PADDLE_ENFORCE(cudaStreamSynchronize(*stream_));
-      paddle::memory::Free(place_, workspace_);
-    }
-    workspace_ = new_workspace;
-    workspace_len_ = required_workspace_len;
-  }
-
-  cudnnHandle_t cudnn_handle_;
-  void* workspace_;
-  size_t workspace_len_;
-
-  const cudaStream_t* stream_;  // not owned;
-  const CUDAPlace place_;
-
-  std::mutex mtx_;
-};
-
-CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
-    : place_(place), cudnn_holder_(nullptr) {
+CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
   SetDeviceId(place_.device);
   compute_capability = GetCUDAComputeCapability(place_.device);
   multi_process = GetCUDAMultiProcessors(place_.device);
@@ -209,14 +154,20 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
   PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_));
   PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_));
   if (dynload::HasCUDNN()) {
-    cudnn_holder_.reset(new CudnnHolder(&stream_, place));
+    PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
+    PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream_));
+  } else {
+    cudnn_handle_ = nullptr;
   }
 }
 
 CUDADeviceContext::~CUDADeviceContext() {
   SetDeviceId(place_.device);
   Wait();
   PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_));
+  if (cudnn_handle_ != nullptr) {
+    PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
+  }
   eigen_stream_.reset();
   eigen_device_.reset();
   PADDLE_ENFORCE(cudaStreamDestroy(stream_));
@@ -245,14 +196,7 @@ cublasHandle_t CUDADeviceContext::cublas_handle() const {
   return cublas_handle_;
 }
 
-cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
-  return cudnn_holder_->cudnn_handle();
-}
-
-void CUDADeviceContext::RunCudnnFuncWithWorkspace(
-    const std::function<void(void*)>& cudnn_func, size_t workspace_len) const {
-  cudnn_holder_->RunFunc(cudnn_func, workspace_len);
-}
+cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; }
 
 cudaStream_t CUDADeviceContext::stream() const { return stream_; }
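
For context on what device_context.cc loses: CudnnHolder owned the cuDNN handle plus one lazily grown workspace shared by every user of the context, serialized through a mutex. A minimal standalone sketch of that caching pattern (std::malloc/std::free stand in for paddle::memory::Alloc/Free, and the stream synchronization is only noted in a comment, so this is an illustration rather than the reverted code):

#include <cstdlib>
#include <functional>
#include <mutex>

// Sketch of the holder idea this commit removes: one lazily grown
// workspace shared by all callers of a device context, serialized by a
// mutex so concurrent cuDNN calls do not clobber each other's buffer.
class WorkspaceHolder {
 public:
  ~WorkspaceHolder() { std::free(workspace_); }

  void RunFunc(const std::function<void(void*)>& cudnn_func, size_t required) {
    std::lock_guard<std::mutex> lock(mtx_);
    if (required > len_) {
      // The real CudnnHolder also synchronized the CUDA stream before
      // freeing, since in-flight kernels may still read the old buffer.
      std::free(workspace_);
      workspace_ = std::malloc(required);
      len_ = required;
    }
    cudnn_func(workspace_);
  }

 private:
  void* workspace_ = nullptr;
  size_t len_ = 0;
  std::mutex mtx_;
};

The trade-off relative to the per-kernel Alloc/Free restored above: fewer device allocations, but a shared mutable buffer and a lock on the hot path.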
