Commit 46d01d7

Author: chengduo
Revert "Revert "Remove workspace_handle in conv_cudnn (#15186)"" (#15290)
test=develop

This reverts commit 358e657.
1 parent a92860a commit 46d01d7

7 files changed, +195 -92 lines changed


paddle/fluid/framework/operator.h

Lines changed: 1 addition & 1 deletion

@@ -391,7 +391,7 @@ class ExecutionContext {
PADDLE_ENFORCE(
dynamic_cast<platform::TemporaryAllocation*>(allocation_ptr) != nullptr,
"The AllocationPtr must be TemporaryAllocation.");
-PADDLE_ENFORCE_EQ(allocation_ptr->size(),
+PADDLE_ENFORCE_GE(allocation_ptr->size(),
framework::product(dim) * sizeof(T));

paddle::framework::Tensor temp_tensor(
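
Note on the hunk above: the size check is relaxed from PADDLE_ENFORCE_EQ to PADDLE_ENFORCE_GE because the temporary allocator may now hand back a cached allocation that is larger than the requested size (see paddle/fluid/platform/temporary_allocator.cc below). A minimal sketch of the allocation pattern the conv kernels rely on, based on the call sites in this diff; the surrounding kernel context (ctx, dev_ctx) and the byte count num_bytes are assumed placeholders:

// Sketch only: request a raw byte buffer through the ExecutionContext and
// keep its pointer. The returned tensor may be backed by a recycled block
// that is larger than num_bytes, which is why operator.h now checks GE
// instead of EQ.
framework::Tensor workspace =
    ctx.AllocateTmpTensor<int8_t, platform::CUDADeviceContext>(
        framework::make_ddim({static_cast<int64_t>(num_bytes)}), dev_ctx);
void* workspace_ptr = static_cast<void*>(workspace.data<int8_t>());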

paddle/fluid/operators/conv_cudnn_op.cu.cc

Lines changed: 85 additions & 64 deletions

@@ -137,7 +137,6 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn conv algorithm ---------------------
cudnnConvolutionFwdAlgo_t algo;
auto handle = dev_ctx.cudnn_handle();
-auto workspace_handle = dev_ctx.cudnn_workspace_handle();

bool half_float = false;
#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
@@ -158,6 +157,8 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
VLOG(5) << "NOT use cudnn_tensor_op_math";
}
#endif
+Tensor cudnn_workspace;
+void* cudnn_workspace_ptr = nullptr;

auto x_dims = framework::vectorize(input->dims());
auto f_dims = framework::vectorize(filter->dims());
@@ -180,21 +181,26 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
.Var(kCUDNNFwdAlgoCache)
->GetMutable<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>();
}
+cudnn_workspace =
+ctx.AllocateTmpTensor<int8_t, platform::CUDADeviceContext>(
+framework::make_ddim(
+{static_cast<int64_t>(workspace_size_limit)}),
+dev_ctx);
+cudnn_workspace_ptr = static_cast<void*>(cudnn_workspace.data<int8_t>());
+
algo = algo_cache->GetAlgorithm(
x_dims, f_dims, strides, paddings, dilations, 0, [&]() {
int returned_algo_count;
std::array<cudnnConvolutionFwdAlgoPerf_t, kNUM_CUDNN_FWD_ALGS>
fwd_perf_stat;
-auto cudnn_find_func = [&](void* cudnn_workspace) {
-CUDNN_ENFORCE(
-platform::dynload::cudnnFindConvolutionForwardAlgorithmEx(
-handle, cudnn_input_desc, input_data, cudnn_filter_desc,
-filter_data, cudnn_conv_desc, cudnn_output_desc,
-output_data, kNUM_CUDNN_FWD_ALGS, &returned_algo_count,
-fwd_perf_stat.data(), cudnn_workspace,
-workspace_size_limit));
-};
-workspace_handle.RunFunc(cudnn_find_func, workspace_size_limit);
+
+CUDNN_ENFORCE(
+platform::dynload::cudnnFindConvolutionForwardAlgorithmEx(
+handle, cudnn_input_desc, input_data, cudnn_filter_desc,
+filter_data, cudnn_conv_desc, cudnn_output_desc,
+output_data, kNUM_CUDNN_FWD_ALGS, &returned_algo_count,
+fwd_perf_stat.data(), cudnn_workspace_ptr,
+workspace_size_limit));

VLOG(3) << "Perf result: (algo: stat, time, memory)";
for (int i = 0; i < returned_algo_count; ++i) {
@@ -219,17 +225,23 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_LE(workspace_size_in_bytes, workspace_size_limit,
"workspace_size to be allocated exceeds the limit");

+// Allocate on GPU memory
+if (!cudnn_workspace_ptr) {
+cudnn_workspace =
+ctx.AllocateTmpTensor<int8_t, platform::CUDADeviceContext>(
+framework::make_ddim(
+{static_cast<int64_t>(workspace_size_in_bytes)}),
+dev_ctx);
+cudnn_workspace_ptr = static_cast<void*>(cudnn_workspace.data<int8_t>());
+}
// ------------------- cudnn conv forward ---------------------
ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
for (int i = 0; i < groups; i++) {
-auto cudnn_func = [&](void* cudnn_workspace) {
-CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
-handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
-cudnn_filter_desc, filter_data + i * group_offset_filter,
-cudnn_conv_desc, algo, cudnn_workspace, workspace_size_in_bytes,
-&beta, cudnn_output_desc, output_data + i * group_offset_out));
-};
-workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
+CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
+handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
+cudnn_filter_desc, filter_data + i * group_offset_filter,
+cudnn_conv_desc, algo, cudnn_workspace_ptr, workspace_size_in_bytes,
+&beta, cudnn_output_desc, output_data + i * group_offset_out));
}
}
};
@@ -353,10 +365,20 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
workspace_size_limit = max_user_size * 1024 * 1024;
}

+Tensor cudnn_workspace;
+void* cudnn_workspace_ptr = nullptr;
+if ((input_data || filter_data) && exhaustive_search) {
+cudnn_workspace =
+ctx.AllocateTmpTensor<int8_t, platform::CUDADeviceContext>(
+framework::make_ddim(
+{static_cast<int64_t>(workspace_size_limit)}),
+dev_ctx);
+cudnn_workspace_ptr = static_cast<void*>(cudnn_workspace.data<int8_t>());
+}
+
auto x_dims = framework::vectorize(input->dims());
auto f_dims = framework::vectorize(filter->dims());
auto handle = dev_ctx.cudnn_handle();
-auto workspace_handle = dev_ctx.cudnn_workspace_handle();
if (input_grad) {
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
if (exhaustive_search) {
@@ -374,25 +396,22 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
->GetMutable<
AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>>();
}
+
data_algo = data_algo_cache->GetAlgorithm(
x_dims, f_dims, strides, paddings, dilations, 0, [&]() {
int returned_algo_count;
std::array<cudnnConvolutionBwdDataAlgoPerf_t,
kNUM_CUDNN_BWD_DATA_ALGS>
data_perf_stat;
-auto cudnn_find_bd_data_func = [&](void* cudnn_workspace) {
-CUDNN_ENFORCE(
-platform::dynload::
-cudnnFindConvolutionBackwardDataAlgorithmEx(
-handle, cudnn_filter_desc, filter_data,
-cudnn_output_grad_desc, output_grad_data,
-cudnn_conv_desc, cudnn_input_desc, input_grad_data,
-kNUM_CUDNN_BWD_DATA_ALGS, &returned_algo_count,
-data_perf_stat.data(), cudnn_workspace,
-workspace_size_limit));
-};
-workspace_handle.RunFunc(cudnn_find_bd_data_func,
-workspace_size_limit);
+
+CUDNN_ENFORCE(platform::dynload::
+cudnnFindConvolutionBackwardDataAlgorithmEx(
+handle, cudnn_filter_desc, filter_data,
+cudnn_output_grad_desc, output_grad_data,
+cudnn_conv_desc, cudnn_input_desc,
+input_grad_data, kNUM_CUDNN_BWD_DATA_ALGS,
+&returned_algo_count, data_perf_stat.data(),
+cudnn_workspace_ptr, workspace_size_limit));

VLOG(3) << "Perf result: (algo: stat, time, memory)";
for (int i = 0; i < returned_algo_count; ++i) {
@@ -443,25 +462,23 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
->GetMutable<
AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>>();
}
+
filter_algo = f_algo_cache->GetAlgorithm(
x_dims, f_dims, strides, paddings, dilations, 0, [&]() {
int returned_algo_count;
std::array<cudnnConvolutionBwdFilterAlgoPerf_t,
kNUM_CUDNN_BWD_FILTER_ALGS>
filter_perf_stat;
-auto cudnn_find_bd_f_func = [&](void* cudnn_workspace) {
-CUDNN_ENFORCE(
-platform::dynload::
-cudnnFindConvolutionBackwardFilterAlgorithmEx(
-handle, cudnn_input_desc, input_data,
-cudnn_output_grad_desc, output_grad_data,
-cudnn_conv_desc, cudnn_filter_desc,
-filter_grad_data, kNUM_CUDNN_BWD_FILTER_ALGS,
-&returned_algo_count, filter_perf_stat.data(),
-cudnn_workspace, workspace_size_limit));
-};
-workspace_handle.RunFunc(cudnn_find_bd_f_func,
-workspace_size_limit);
+
+CUDNN_ENFORCE(
+platform::dynload::
+cudnnFindConvolutionBackwardFilterAlgorithmEx(
+handle, cudnn_input_desc, input_data,
+cudnn_output_grad_desc, output_grad_data,
+cudnn_conv_desc, cudnn_filter_desc, filter_grad_data,
+kNUM_CUDNN_BWD_FILTER_ALGS, &returned_algo_count,
+filter_perf_stat.data(), cudnn_workspace_ptr,
+workspace_size_limit));
return filter_perf_stat[0].algo;
});
VLOG(3) << "cuDNN backward filter algo " << filter_algo;
@@ -482,38 +499,42 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size);
}

+// ------------------- cudnn conv workspace ---------------------
+if (!cudnn_workspace_ptr) {
+cudnn_workspace =
+ctx.AllocateTmpTensor<int8_t, platform::CUDADeviceContext>(
+framework::make_ddim(
+{static_cast<int64_t>(workspace_size_in_bytes)}),
+dev_ctx);
+cudnn_workspace_ptr = static_cast<void*>(cudnn_workspace.data<int8_t>());
+}
+
// ------------------- cudnn conv backward data ---------------------
ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
if (input_grad) {
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
// Because beta is zero, it is unnecessary to reset input_grad.

for (int i = 0; i < groups; i++) {
-auto cudnn_func = [&](void* cudnn_workspace) {
-CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
-handle, &alpha, cudnn_filter_desc,
-filter_data + i * group_offset_filter, cudnn_output_grad_desc,
-output_grad_data + i * group_offset_out, cudnn_conv_desc,
-data_algo, cudnn_workspace, workspace_size_in_bytes, &beta,
-cudnn_input_desc, input_grad_data + i * group_offset_in));
-};
-workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
+CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
+handle, &alpha, cudnn_filter_desc,
+filter_data + i * group_offset_filter, cudnn_output_grad_desc,
+output_grad_data + i * group_offset_out, cudnn_conv_desc, data_algo,
+cudnn_workspace_ptr, workspace_size_in_bytes, &beta,
+cudnn_input_desc, input_grad_data + i * group_offset_in));
}
}
// ------------------- cudnn conv backward filter ---------------------
if (filter_grad) {
T* filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
// Because beta is zero, it is unnecessary to reset filter_grad.
for (int i = 0; i < groups; i++) {
-auto cudnn_func = [&](void* cudnn_workspace) {
-CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
-handle, &alpha, cudnn_input_desc,
-input_data + i * group_offset_in, cudnn_output_grad_desc,
-output_grad_data + i * group_offset_out, cudnn_conv_desc,
-filter_algo, cudnn_workspace, workspace_size_in_bytes, &beta,
-cudnn_filter_desc, filter_grad_data + i * group_offset_filter));
-};
-workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
+CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
+handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
+cudnn_output_grad_desc, output_grad_data + i * group_offset_out,
+cudnn_conv_desc, filter_algo, cudnn_workspace_ptr,
+workspace_size_in_bytes, &beta, cudnn_filter_desc,
+filter_grad_data + i * group_offset_filter));
}
}
}
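
The pattern change in this file, condensed: before this commit every cuDNN call borrowed workspace memory through a callback passed to workspace_handle.RunFunc(); after it, the kernel allocates the workspace once as a temporary tensor and passes the raw pointer straight to cuDNN. A simplified before/after sketch drawn from the hunks above (per-group pointer offsets omitted):

// Before (removed): workspace is only valid inside the RunFunc callback.
auto cudnn_func = [&](void* cudnn_workspace) {
  CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
      handle, &alpha, cudnn_input_desc, input_data, cudnn_filter_desc,
      filter_data, cudnn_conv_desc, algo, cudnn_workspace,
      workspace_size_in_bytes, &beta, cudnn_output_desc, output_data));
};
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);

// After: workspace lives in a temporary tensor owned by the kernel, and the
// same pointer is reused for algorithm search and for the conv call itself.
Tensor cudnn_workspace =
    ctx.AllocateTmpTensor<int8_t, platform::CUDADeviceContext>(
        framework::make_ddim({static_cast<int64_t>(workspace_size_in_bytes)}),
        dev_ctx);
void* cudnn_workspace_ptr = static_cast<void*>(cudnn_workspace.data<int8_t>());
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
    handle, &alpha, cudnn_input_desc, input_data, cudnn_filter_desc,
    filter_data, cudnn_conv_desc, algo, cudnn_workspace_ptr,
    workspace_size_in_bytes, &beta, cudnn_output_desc, output_data));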

paddle/fluid/platform/device_context.h

Lines changed: 1 addition & 1 deletion

@@ -61,7 +61,7 @@ namespace platform {
* the allocations of temp_allocation_queue:
* - when the Stream calls cudaStreamSynchronize;
* - when the allocation size of opportunities exceeds a certain threshold
-* (defined by FLAGS_limit_of_temporary_allocation).
+* (defined by FLAGS_limit_of_tmp_allocation).
*
* */
class DeviceTemporaryAllocator {

paddle/fluid/platform/temporary_allocator.cc

Lines changed: 49 additions & 14 deletions

@@ -15,8 +15,15 @@
#include "paddle/fluid/platform/temporary_allocator.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"

-DEFINE_double(limit_of_temporary_allocation, -1,
-"The up limit of temporary_allocation size.");
+DEFINE_int64(limit_of_tmp_allocation, -1,
+"The up limit of temporary_allocation size.");
+DEFINE_double(times_excess_than_required_tmp_allocation, 2,
+"times_excess_than_required_tmp_allocation indicates the "
+"max size the TemporaryAllocator can return. For example, "
+"if the required memory size is N, and "
+"times_excess_than_required_tmp_allocation is 2.0, "
+"the TemporaryAllocator will return the available allocation "
+"that the range of size is N ~ 2*N.");

namespace paddle {
namespace platform {
@@ -29,53 +36,60 @@ TemporaryAllocation::TemporaryAllocation(
underlying_allocation_(std::move(underlying_allocation)) {}

TemporaryAllocator::TemporaryAllocator(platform::Place place) : place_(place) {
-temp_mem_queue_.reset(new std::deque<TemporaryAllocation *>());
+temp_mem_map_.reset(new std::multimap<size_t, TemporaryAllocation *>());
}

bool TemporaryAllocator::IsAllocThreadSafe() const { return true; }

void TemporaryAllocator::Release(const std::function<void()> &callback) {
-std::shared_ptr<std::deque<TemporaryAllocation *>> t_allocations;
+std::unique_ptr<std::multimap<size_t, TemporaryAllocation *>> t_allocations;
{
std::unique_lock<std::mutex> lock(mtx_);
callback();
-t_allocations = temp_mem_queue_;
-temp_mem_queue_.reset(new std::deque<TemporaryAllocation *>());
+t_allocations.swap(temp_mem_map_);
+temp_mem_map_.reset(new std::multimap<size_t, TemporaryAllocation *>());
wait_delete_mem_ = 0;
}
+
for (auto tmp : *t_allocations) {
-VLOG(10) << "Delete temporary allocation " << tmp->ptr()
-<< " size: " << tmp->size();
-delete tmp;
+VLOG(10) << "Delete temporary allocation " << tmp.second->ptr()
+<< " size: " << tmp.second->size();
+delete tmp.second;
}
}

void TemporaryAllocator::Free(alloc::Allocation *allocation) {
auto *temp_allocation = dynamic_cast<TemporaryAllocation *>(allocation);
PADDLE_ENFORCE_NOT_NULL(temp_allocation);
if (platform::is_gpu_place(temp_allocation->place())) {
+PADDLE_ENFORCE(platform::is_same_place(temp_allocation->place(), place_),
+"The place should be the same.");
size_t wait_delete_mem = 0;
{
std::unique_lock<std::mutex> lock(mtx_);
-temp_mem_queue_->emplace_back(temp_allocation);
+temp_mem_map_->emplace(temp_allocation->size(), temp_allocation);
wait_delete_mem_ += temp_allocation->size();
wait_delete_mem = wait_delete_mem_;
VLOG(10) << "Move temporary allocation: " << temp_allocation->ptr()
<< " to delete queue: " << temp_allocation->size() << "; "
-<< "wait_delete_mem: " << wait_delete_mem_;
+<< "wait_delete_mem: " << wait_delete_mem;
}
-if (FLAGS_limit_of_temporary_allocation > 0 &&
-wait_delete_mem > FLAGS_limit_of_temporary_allocation) {
+
+if (FLAGS_limit_of_tmp_allocation > 0 &&
+wait_delete_mem > static_cast<size_t>(FLAGS_limit_of_tmp_allocation)) {
+PADDLE_ENFORCE(callback_ != nullptr, "The callback is non-initialized.");
Release(callback_);
}
return;
}
+VLOG(10) << "Delete temporary allocation " << temp_allocation->ptr()
+<< " size: " << temp_allocation->size();
delete temp_allocation;
}

size_t TemporaryAllocator::TemporaryAllocationQueueSize() {
std::unique_lock<std::mutex> lock(mtx_);
-return temp_mem_queue_ ? temp_mem_queue_->size() : 0;
+return temp_mem_map_ ? temp_mem_map_->size() : 0;
}

void TemporaryAllocator::SetCallback(const std::function<void()> &callback) {
@@ -84,6 +98,27 @@ void TemporaryAllocator::SetCallback(const std::function<void()> &callback) {

alloc::Allocation *TemporaryAllocator::AllocateImpl(
size_t size, alloc::Allocator::Attr attr) {
+{
+// Find available allocation in temp_mem_map.
+std::unique_lock<std::mutex> lock(mtx_);
+if (temp_mem_map_->size()) {
+auto it = temp_mem_map_->lower_bound(size);
+// FIXME(zcd): Not sure the best value of excess fraction.
+if (it != temp_mem_map_->end() &&
+it->first <
+static_cast<size_t>(
+size * FLAGS_times_excess_than_required_tmp_allocation)) {
+auto tmp_ptr = it->second;
+temp_mem_map_->erase(it);
+wait_delete_mem_ -= tmp_ptr->size();
+VLOG(10) << "Reuse temporary allocation: " << tmp_ptr->ptr() << ": "
+<< tmp_ptr->size();
+return tmp_ptr;
+}
+}
+}
+// If not find the the available allocation, get allocation from
+// AllocatorFacadeInstance.
auto raw_allocation =
alloc::AllocatorFacade::Instance().Alloc(place_, size, attr);
auto temp_mem = new TemporaryAllocation(std::move(raw_allocation));
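
What the allocator change amounts to: freed temporary allocations are now kept in a size-keyed std::multimap instead of a deque, and AllocateImpl first tries to hand back the smallest cached block whose size is below FLAGS_times_excess_than_required_tmp_allocation times the request before falling back to AllocatorFacade. The standalone sketch below (not Paddle code) illustrates that lower_bound lookup policy in plain C++ so it can be compiled and run on its own:

#include <cstddef>
#include <iostream>
#include <map>

// A size-keyed pool that returns the smallest cached block within
// `times_excess` times the requested size, mirroring the lookup in
// TemporaryAllocator::AllocateImpl above.
class SimplePool {
 public:
  explicit SimplePool(double times_excess) : times_excess_(times_excess) {}

  // Cache a freed block of `size` bytes for later reuse.
  void Free(size_t size, void* ptr) { pool_.emplace(size, ptr); }

  // Return a cached block whose size is in [size, size * times_excess_),
  // or nullptr if no cached block qualifies.
  void* TryReuse(size_t size) {
    auto it = pool_.lower_bound(size);
    if (it != pool_.end() &&
        it->first < static_cast<size_t>(size * times_excess_)) {
      void* ptr = it->second;
      pool_.erase(it);
      return ptr;
    }
    return nullptr;  // caller falls back to a fresh allocation
  }

 private:
  double times_excess_;
  std::multimap<size_t, void*> pool_;
};

int main() {
  SimplePool pool(/*times_excess=*/2.0);
  int buf[256];
  pool.Free(sizeof(buf), buf);  // cache a 1024-byte block (4-byte int assumed)
  std::cout << (pool.TryReuse(400) != nullptr) << "\n";  // 0: 1024 >= 400 * 2
  std::cout << (pool.TryReuse(600) != nullptr) << "\n";  // 1: 1024 <  600 * 2
}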
