@@ -137,7 +137,6 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
     // ------------------- cudnn conv algorithm ---------------------
     cudnnConvolutionFwdAlgo_t algo;
     auto handle = dev_ctx.cudnn_handle();
-    auto workspace_handle = dev_ctx.cudnn_workspace_handle();
 
     bool half_float = false;
 #if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
@@ -158,6 +157,8 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
       VLOG(5) << "NOT use cudnn_tensor_op_math";
     }
 #endif
+    Tensor cudnn_workspace;
+    void* cudnn_workspace_ptr = nullptr;
 
     auto x_dims = framework::vectorize(input->dims());
     auto f_dims = framework::vectorize(filter->dims());
@@ -180,21 +181,26 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
                 .Var(kCUDNNFwdAlgoCache)
                 ->GetMutable<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>();
       }
+      cudnn_workspace =
+          ctx.AllocateTmpTensor<int8_t, platform::CUDADeviceContext>(
+              framework::make_ddim(
+                  {static_cast<int64_t>(workspace_size_limit)}),
+              dev_ctx);
+      cudnn_workspace_ptr = static_cast<void*>(cudnn_workspace.data<int8_t>());
+
       algo = algo_cache->GetAlgorithm(
           x_dims, f_dims, strides, paddings, dilations, 0, [&]() {
             int returned_algo_count;
             std::array<cudnnConvolutionFwdAlgoPerf_t, kNUM_CUDNN_FWD_ALGS>
                 fwd_perf_stat;
-            auto cudnn_find_func = [&](void* cudnn_workspace) {
-              CUDNN_ENFORCE(
-                  platform::dynload::cudnnFindConvolutionForwardAlgorithmEx(
-                      handle, cudnn_input_desc, input_data, cudnn_filter_desc,
-                      filter_data, cudnn_conv_desc, cudnn_output_desc,
-                      output_data, kNUM_CUDNN_FWD_ALGS, &returned_algo_count,
-                      fwd_perf_stat.data(), cudnn_workspace,
-                      workspace_size_limit));
-            };
-            workspace_handle.RunFunc(cudnn_find_func, workspace_size_limit);
+
+            CUDNN_ENFORCE(
+                platform::dynload::cudnnFindConvolutionForwardAlgorithmEx(
+                    handle, cudnn_input_desc, input_data, cudnn_filter_desc,
+                    filter_data, cudnn_conv_desc, cudnn_output_desc,
+                    output_data, kNUM_CUDNN_FWD_ALGS, &returned_algo_count,
+                    fwd_perf_stat.data(), cudnn_workspace_ptr,
+                    workspace_size_limit));
 
             VLOG(3) << "Perf result: (algo: stat, time, memory)";
             for (int i = 0; i < returned_algo_count; ++i) {
@@ -219,17 +225,23 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE_LE(workspace_size_in_bytes, workspace_size_limit,
                       "workspace_size to be allocated exceeds the limit");
 
+    // Allocate on GPU memory
+    if (!cudnn_workspace_ptr) {
+      cudnn_workspace =
+          ctx.AllocateTmpTensor<int8_t, platform::CUDADeviceContext>(
+              framework::make_ddim(
+                  {static_cast<int64_t>(workspace_size_in_bytes)}),
+              dev_ctx);
+      cudnn_workspace_ptr = static_cast<void*>(cudnn_workspace.data<int8_t>());
+    }
     // ------------------- cudnn conv forward ---------------------
     ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
     for (int i = 0; i < groups; i++) {
-      auto cudnn_func = [&](void* cudnn_workspace) {
-        CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
-            handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
-            cudnn_filter_desc, filter_data + i * group_offset_filter,
-            cudnn_conv_desc, algo, cudnn_workspace, workspace_size_in_bytes,
-            &beta, cudnn_output_desc, output_data + i * group_offset_out));
-      };
-      workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
+      CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
+          handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
+          cudnn_filter_desc, filter_data + i * group_offset_filter,
+          cudnn_conv_desc, algo, cudnn_workspace_ptr, workspace_size_in_bytes,
+          &beta, cudnn_output_desc, output_data + i * group_offset_out));
     }
   }
 };
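
The forward-kernel hunks above all make the same change: instead of wrapping each cuDNN call in a lambda handed to workspace_handle.RunFunc, the kernel now allocates a temporary workspace tensor itself and passes the raw pointer straight into the cuDNN calls. The following standalone C++ sketch (not Paddle code; run_with_workspace and conv_kernel are hypothetical stand-ins, and std::vector stands in for the framework's temporary tensor) contrasts the two styles:

// Minimal sketch of the refactor pattern, under the assumptions above.
#include <cstddef>
#include <cstdio>
#include <functional>
#include <vector>

// Old style: the handle owns the scratch buffer and lends it to a callback.
void run_with_workspace(const std::function<void(void*)>& fn,
                        std::size_t bytes) {
  std::vector<char> scratch(bytes);  // acquired per RunFunc invocation
  fn(scratch.data());
}

// Stand-in for a cuDNN call that needs workspace memory.
void conv_kernel(void* workspace, std::size_t bytes) {
  std::printf("conv using %zu workspace bytes at %p\n", bytes, workspace);
}

int main() {
  const std::size_t workspace_bytes = 1 << 20;

  // Before: every call site wraps the cuDNN call in a lambda.
  run_with_workspace([&](void* ws) { conv_kernel(ws, workspace_bytes); },
                     workspace_bytes);

  // After: allocate once, keep a raw pointer, pass it to each call
  // (the role cudnn_workspace / cudnn_workspace_ptr play in the diff).
  std::vector<char> cudnn_workspace(workspace_bytes);
  void* cudnn_workspace_ptr = cudnn_workspace.data();
  conv_kernel(cudnn_workspace_ptr, workspace_bytes);
  return 0;
}

The same substitution is applied to the gradient kernel in the hunks below.
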
@@ -353,10 +365,20 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
       workspace_size_limit = max_user_size * 1024 * 1024;
     }
 
+    Tensor cudnn_workspace;
+    void* cudnn_workspace_ptr = nullptr;
+    if ((input_data || filter_data) && exhaustive_search) {
+      cudnn_workspace =
+          ctx.AllocateTmpTensor<int8_t, platform::CUDADeviceContext>(
+              framework::make_ddim(
+                  {static_cast<int64_t>(workspace_size_limit)}),
+              dev_ctx);
+      cudnn_workspace_ptr = static_cast<void*>(cudnn_workspace.data<int8_t>());
+    }
+
     auto x_dims = framework::vectorize(input->dims());
     auto f_dims = framework::vectorize(filter->dims());
     auto handle = dev_ctx.cudnn_handle();
-    auto workspace_handle = dev_ctx.cudnn_workspace_handle();
     if (input_grad) {
       T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
       if (exhaustive_search) {
@@ -374,25 +396,22 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
                 ->GetMutable<
                     AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>>();
       }
+
       data_algo = data_algo_cache->GetAlgorithm(
           x_dims, f_dims, strides, paddings, dilations, 0, [&]() {
             int returned_algo_count;
             std::array<cudnnConvolutionBwdDataAlgoPerf_t,
                        kNUM_CUDNN_BWD_DATA_ALGS>
                 data_perf_stat;
-            auto cudnn_find_bd_data_func = [&](void* cudnn_workspace) {
-              CUDNN_ENFORCE(
-                  platform::dynload::
-                      cudnnFindConvolutionBackwardDataAlgorithmEx(
-                          handle, cudnn_filter_desc, filter_data,
-                          cudnn_output_grad_desc, output_grad_data,
-                          cudnn_conv_desc, cudnn_input_desc, input_grad_data,
-                          kNUM_CUDNN_BWD_DATA_ALGS, &returned_algo_count,
-                          data_perf_stat.data(), cudnn_workspace,
-                          workspace_size_limit));
-            };
-            workspace_handle.RunFunc(cudnn_find_bd_data_func,
-                                     workspace_size_limit);
+
+            CUDNN_ENFORCE(platform::dynload::
+                              cudnnFindConvolutionBackwardDataAlgorithmEx(
+                                  handle, cudnn_filter_desc, filter_data,
+                                  cudnn_output_grad_desc, output_grad_data,
+                                  cudnn_conv_desc, cudnn_input_desc,
+                                  input_grad_data, kNUM_CUDNN_BWD_DATA_ALGS,
+                                  &returned_algo_count, data_perf_stat.data(),
+                                  cudnn_workspace_ptr, workspace_size_limit));
 
             VLOG(3) << "Perf result: (algo: stat, time, memory)";
             for (int i = 0; i < returned_algo_count; ++i) {
@@ -443,25 +462,23 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
                 ->GetMutable<
                     AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>>();
       }
+
       filter_algo = f_algo_cache->GetAlgorithm(
           x_dims, f_dims, strides, paddings, dilations, 0, [&]() {
             int returned_algo_count;
             std::array<cudnnConvolutionBwdFilterAlgoPerf_t,
                        kNUM_CUDNN_BWD_FILTER_ALGS>
                 filter_perf_stat;
-            auto cudnn_find_bd_f_func = [&](void* cudnn_workspace) {
-              CUDNN_ENFORCE(
-                  platform::dynload::
-                      cudnnFindConvolutionBackwardFilterAlgorithmEx(
-                          handle, cudnn_input_desc, input_data,
-                          cudnn_output_grad_desc, output_grad_data,
-                          cudnn_conv_desc, cudnn_filter_desc,
-                          filter_grad_data, kNUM_CUDNN_BWD_FILTER_ALGS,
-                          &returned_algo_count, filter_perf_stat.data(),
-                          cudnn_workspace, workspace_size_limit));
-            };
-            workspace_handle.RunFunc(cudnn_find_bd_f_func,
-                                     workspace_size_limit);
+
+            CUDNN_ENFORCE(
+                platform::dynload::
+                    cudnnFindConvolutionBackwardFilterAlgorithmEx(
+                        handle, cudnn_input_desc, input_data,
+                        cudnn_output_grad_desc, output_grad_data,
+                        cudnn_conv_desc, cudnn_filter_desc, filter_grad_data,
+                        kNUM_CUDNN_BWD_FILTER_ALGS, &returned_algo_count,
+                        filter_perf_stat.data(), cudnn_workspace_ptr,
+                        workspace_size_limit));
             return filter_perf_stat[0].algo;
           });
       VLOG(3) << "cuDNN backward filter algo " << filter_algo;
@@ -482,38 +499,42 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
       workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size);
     }
 
+    // ------------------- cudnn conv workspace ---------------------
+    if (!cudnn_workspace_ptr) {
+      cudnn_workspace =
+          ctx.AllocateTmpTensor<int8_t, platform::CUDADeviceContext>(
+              framework::make_ddim(
+                  {static_cast<int64_t>(workspace_size_in_bytes)}),
+              dev_ctx);
+      cudnn_workspace_ptr = static_cast<void*>(cudnn_workspace.data<int8_t>());
+    }
+
     // ------------------- cudnn conv backward data ---------------------
     ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
     if (input_grad) {
       T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
       // Because beta is zero, it is unnecessary to reset input_grad.
 
       for (int i = 0; i < groups; i++) {
-        auto cudnn_func = [&](void* cudnn_workspace) {
-          CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
-              handle, &alpha, cudnn_filter_desc,
-              filter_data + i * group_offset_filter, cudnn_output_grad_desc,
-              output_grad_data + i * group_offset_out, cudnn_conv_desc,
-              data_algo, cudnn_workspace, workspace_size_in_bytes, &beta,
-              cudnn_input_desc, input_grad_data + i * group_offset_in));
-        };
-        workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
+        CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
+            handle, &alpha, cudnn_filter_desc,
+            filter_data + i * group_offset_filter, cudnn_output_grad_desc,
+            output_grad_data + i * group_offset_out, cudnn_conv_desc, data_algo,
+            cudnn_workspace_ptr, workspace_size_in_bytes, &beta,
+            cudnn_input_desc, input_grad_data + i * group_offset_in));
       }
     }
     // ------------------- cudnn conv backward filter ---------------------
     if (filter_grad) {
       T* filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
       // Because beta is zero, it is unnecessary to reset filter_grad.
       for (int i = 0; i < groups; i++) {
-        auto cudnn_func = [&](void* cudnn_workspace) {
-          CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
-              handle, &alpha, cudnn_input_desc,
-              input_data + i * group_offset_in, cudnn_output_grad_desc,
-              output_grad_data + i * group_offset_out, cudnn_conv_desc,
-              filter_algo, cudnn_workspace, workspace_size_in_bytes, &beta,
-              cudnn_filter_desc, filter_grad_data + i * group_offset_filter));
-        };
-        workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
+        CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
+            handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
+            cudnn_output_grad_desc, output_grad_data + i * group_offset_out,
+            cudnn_conv_desc, filter_algo, cudnn_workspace_ptr,
+            workspace_size_in_bytes, &beta, cudnn_filter_desc,
+            filter_grad_data + i * group_offset_filter));
       }
     }
   }
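
Note the allocation strategy in the gradient kernel: the workspace is allocated eagerly at workspace_size_limit only when an exhaustive algorithm search will actually run (and there is a gradient to compute), and otherwise deferred until the chosen algorithms fix the exact workspace_size_in_bytes. Below is a minimal standalone sketch of that guard, under the assumption that std::vector stands in for the framework's temporary tensor and the sizes are illustrative only:

// Sketch of the lazy-allocation guard; not Paddle code.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const bool exhaustive_search = false;  // e.g. taken from an op attribute
  const int64_t workspace_size_limit = 64 << 20;
  std::vector<int8_t> cudnn_workspace;
  void* cudnn_workspace_ptr = nullptr;

  if (exhaustive_search) {
    // The Find*AlgorithmEx calls may probe every algorithm, so they need
    // the full search-time limit before the real requirement is known.
    cudnn_workspace.resize(workspace_size_limit);
    cudnn_workspace_ptr = cudnn_workspace.data();
  }

  // ... algorithm selection determines the actual requirement ...
  const int64_t workspace_size_in_bytes = 8 << 20;

  if (!cudnn_workspace_ptr) {
    // Deferred path: allocate exactly what the chosen algorithms need.
    cudnn_workspace.resize(workspace_size_in_bytes);
    cudnn_workspace_ptr = cudnn_workspace.data();
  }
  std::printf("workspace bytes: %zu\n", cudnn_workspace.size());
  return 0;
}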