@@ -56,26 +56,56 @@ class CudnnConvOpKernel : public framework::OpKernel<T> {
56
56
ScopedFilterDescriptor filter_desc;
57
57
ScopedConvolutionDescriptor conv_desc;
58
58
DataLayout layout = DataLayout::kNCHW ;
59
+ if (input->dims ().size () == 5 ) {
60
+ layout = DataLayout::kNCDHW ;
61
+ }
62
+
63
+ cudnnConvolutionDescriptor_t cudnn_conv_desc =
64
+ conv_desc.descriptor <T>(paddings, strides, dilations);
65
+
66
+ #if CUDNN_VERSION_MIN(7, 0, 0)
67
+ // cudnn 7 can support groups, no need to do it manually
68
+ // FIXME(typhoonzero): find a better way to disable groups
69
+ // rather than setting it to 1.
70
+ PADDLE_ENFORCE (platform::dynload::cudnnSetConvolutionGroupCount (
71
+ cudnn_conv_desc, groups));
72
+ groups = 1 ;
73
+ #endif
59
74
60
75
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor <T>(
61
76
layout, framework::vectorize2int (input->dims ()), groups);
62
77
cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor <T>(
63
78
layout, framework::vectorize2int (output->dims ()), groups);
64
79
cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor <T>(
65
80
layout, framework::vectorize2int (filter->dims ()), groups);
66
- cudnnConvolutionDescriptor_t cudnn_conv_desc =
67
- conv_desc.descriptor <T>(paddings, strides, dilations);
68
81
69
82
int input_channels = input->dims ()[1 ];
70
- int input_height = input->dims ()[2 ];
71
- int input_width = input->dims ()[3 ];
72
- int output_channels = output->dims ()[1 ];
73
- int output_height = output->dims ()[2 ];
74
- int output_width = output->dims ()[3 ];
83
+ int input_height, input_width, input_depth;
84
+ if (input->dims ().size () == 5 ) {
85
+ input_depth = input->dims ()[2 ];
86
+ input_height = input->dims ()[3 ];
87
+ input_width = input->dims ()[4 ];
88
+ } else { // dim size is enforced in InferShape
89
+ input_depth = 1 ;
90
+ input_height = input->dims ()[2 ];
91
+ input_width = input->dims ()[3 ];
92
+ }
93
+ int output_channels = filter->dims ()[0 ];
94
+ int output_height, output_width, output_depth;
95
+ if (output->dims ().size () == 5 ) {
96
+ output_depth = output->dims ()[2 ];
97
+ output_height = output->dims ()[3 ];
98
+ output_width = output->dims ()[4 ];
99
+ } else {
100
+ output_depth = 1 ;
101
+ output_height = output->dims ()[2 ];
102
+ output_width = output->dims ()[3 ];
103
+ }
75
104
76
- int group_offset_in = input_channels / groups * input_height * input_width;
105
+ int group_offset_in =
106
+ input_channels / groups * input_height * input_width * input_depth;
77
107
int group_offset_out =
78
- output_channels / groups * output_height * output_width;
108
+ output_channels / groups * output_height * output_width * output_depth ;
79
109
int group_offset_filter = filter->numel () / groups;
80
110
// ------------------- cudnn conv workspace ---------------------
81
111
void * cudnn_workspace = nullptr ;
@@ -138,12 +168,26 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
138
168
// ------------------- cudnn descriptors ---------------------
139
169
ScopedTensorDescriptor input_desc;
140
170
ScopedTensorDescriptor output_grad_desc;
141
- ScopedTensorDescriptor input_grad_desc;
142
171
143
172
ScopedFilterDescriptor filter_desc;
144
173
ScopedFilterDescriptor filter_grad_desc;
145
174
ScopedConvolutionDescriptor conv_desc;
146
175
DataLayout layout = DataLayout::kNCHW ;
176
+ if (input->dims ().size () == 5 ) {
177
+ layout = DataLayout::kNCDHW ;
178
+ }
179
+
180
+ cudnnConvolutionDescriptor_t cudnn_conv_desc =
181
+ conv_desc.descriptor <T>(paddings, strides, dilations);
182
+
183
+ #if CUDNN_VERSION_MIN(7, 0, 0)
184
+ // cudnn 7 can support groups, no need to do it manually
185
+ // FIXME(typhoonzero): find a better way to disable groups
186
+ // rather than setting it to 1.
187
+ PADDLE_ENFORCE (platform::dynload::cudnnSetConvolutionGroupCount (
188
+ cudnn_conv_desc, groups));
189
+ groups = 1 ;
190
+ #endif
147
191
148
192
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor <T>(
149
193
layout, framework::vectorize2int (input->dims ()), groups);
@@ -152,22 +196,35 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
152
196
layout, framework::vectorize2int (output_grad->dims ()), groups);
153
197
cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor <T>(
154
198
layout, framework::vectorize2int (filter->dims ()), groups);
155
- cudnnTensorDescriptor_t cudnn_input_grad_desc = nullptr ;
156
- cudnnFilterDescriptor_t cudnn_filter_grad_desc = nullptr ;
157
-
158
- cudnnConvolutionDescriptor_t cudnn_conv_desc =
159
- conv_desc.descriptor <T>(paddings, strides, dilations);
160
199
161
200
int input_channels = input->dims ()[1 ];
162
- int input_height = input->dims ()[2 ];
163
- int input_width = input->dims ()[3 ];
201
+ int input_height, input_width, input_depth;
202
+ if (input->dims ().size () == 5 ) {
203
+ input_depth = input->dims ()[2 ];
204
+ input_height = input->dims ()[3 ];
205
+ input_width = input->dims ()[4 ];
206
+ } else { // dim size is enforced in InferShape
207
+ input_depth = 1 ;
208
+ input_height = input->dims ()[2 ];
209
+ input_width = input->dims ()[3 ];
210
+ }
211
+
164
212
int output_grad_channels = filter->dims ()[0 ];
165
- int output_grad_height = output_grad->dims ()[2 ];
166
- int output_grad_width = output_grad->dims ()[3 ];
213
+ int output_grad_height, output_grad_width, output_grad_depth;
214
+ if (input->dims ().size () == 5 ) {
215
+ output_grad_depth = output_grad->dims ()[2 ];
216
+ output_grad_height = output_grad->dims ()[3 ];
217
+ output_grad_width = output_grad->dims ()[4 ];
218
+ } else {
219
+ output_grad_depth = 1 ;
220
+ output_grad_height = output_grad->dims ()[2 ];
221
+ output_grad_width = output_grad->dims ()[3 ];
222
+ }
167
223
168
- int group_offset_in = input_channels / groups * input_height * input_width;
169
- int group_offset_out =
170
- output_grad_channels / groups * output_grad_height * output_grad_width;
224
+ int group_offset_in =
225
+ input_channels / groups * input_height * input_width * input_depth;
226
+ int group_offset_out = output_grad_channels / groups * output_grad_height *
227
+ output_grad_width * output_grad_depth;
171
228
int group_offset_filter = filter->numel () / groups;
172
229
// ------------------- cudnn backward algorithm ---------------------
173
230
cudnnConvolutionBwdDataAlgo_t data_algo;
@@ -180,8 +237,6 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
180
237
181
238
auto handle = ctx.cuda_device_context ().cudnn_handle ();
182
239
if (input_grad) {
183
- cudnn_input_grad_desc = input_grad_desc.descriptor <T>(
184
- layout, framework::vectorize2int (input_grad->dims ()), groups);
185
240
PADDLE_ENFORCE (
186
241
platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm (
187
242
handle, cudnn_filter_desc,
@@ -190,19 +245,17 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
190
245
cudnn_output_grad_desc, cudnn_conv_desc,
191
246
// dxDesc: Handle to the previously initialized output tensor
192
247
// descriptor.
193
- cudnn_input_grad_desc ,
248
+ cudnn_input_desc ,
194
249
CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
195
250
workspace_size_limit, &data_algo));
196
251
PADDLE_ENFORCE (
197
252
platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize (
198
253
handle, cudnn_filter_desc, cudnn_output_grad_desc,
199
- cudnn_conv_desc, cudnn_input_grad_desc , data_algo, &tmp_size));
254
+ cudnn_conv_desc, cudnn_input_desc , data_algo, &tmp_size));
200
255
workspace_size_in_bytes = std::max (workspace_size_in_bytes, tmp_size);
201
256
}
202
257
203
258
if (filter_grad) {
204
- cudnn_filter_grad_desc = filter_grad_desc.descriptor <T>(
205
- layout, framework::vectorize2int (filter_grad->dims ()), groups);
206
259
PADDLE_ENFORCE (
207
260
platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm (
208
261
handle, cudnn_input_desc, cudnn_output_grad_desc, cudnn_conv_desc,
@@ -222,7 +275,6 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
222
275
platform::GPUPlace gpu = boost::get<platform::GPUPlace>(ctx.GetPlace ());
223
276
cudnn_workspace = paddle::memory::Alloc (gpu, workspace_size_in_bytes);
224
277
// ------------------- cudnn conv backward data ---------------------
225
- // FIXME(typhoonzero): template type T may not be the same as cudnn call.
226
278
T alpha = 1 .0f , beta = 0 .0f ;
227
279
if (input_grad) {
228
280
T* input_grad_data = input_grad->mutable_data <T>(ctx.GetPlace ());
@@ -233,21 +285,20 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
233
285
handle, &alpha, cudnn_filter_desc,
234
286
filter_data + i * group_offset_filter, cudnn_output_grad_desc,
235
287
output_grad_data + i * group_offset_out, cudnn_conv_desc, data_algo,
236
- cudnn_workspace, workspace_size_in_bytes, &beta,
237
- cudnn_input_grad_desc, input_grad_data + i * group_offset_in));
288
+ cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_input_desc,
289
+ input_grad_data + i * group_offset_in));
238
290
}
239
291
}
240
292
// ------------------- cudnn conv backward filter ---------------------
241
293
if (filter_grad) {
242
294
T* filter_grad_data = filter_grad->mutable_data <T>(ctx.GetPlace ());
243
295
// Because beta is zero, it is unnecessary to reset filter_grad.
244
-
245
296
for (int i = 0 ; i < groups; i++) {
246
297
PADDLE_ENFORCE (platform::dynload::cudnnConvolutionBackwardFilter (
247
298
handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
248
299
cudnn_output_grad_desc, output_grad_data + i * group_offset_out,
249
300
cudnn_conv_desc, filter_algo, cudnn_workspace,
250
- workspace_size_in_bytes, &beta, cudnn_filter_grad_desc ,
301
+ workspace_size_in_bytes, &beta, cudnn_filter_desc ,
251
302
filter_grad_data + i * group_offset_filter));
252
303
}
253
304
}
@@ -259,8 +310,16 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
259
310
} // namespace operators
260
311
} // namespace paddle
261
312
262
- REGISTER_OP_GPU_KERNEL (conv_cudnn, paddle::operators::CudnnConvOpKernel<float >,
313
+ REGISTER_OP_GPU_KERNEL (conv2d_cudnn,
314
+ paddle::operators::CudnnConvOpKernel<float >,
315
+ paddle::operators::CudnnConvOpKernel<double >);
316
+ REGISTER_OP_GPU_KERNEL (conv2d_cudnn_grad,
317
+ paddle::operators::CudnnConvGradOpKernel<float >,
318
+ paddle::operators::CudnnConvGradOpKernel<double >);
319
+
320
+ REGISTER_OP_GPU_KERNEL (conv3d_cudnn,
321
+ paddle::operators::CudnnConvOpKernel<float >,
263
322
paddle::operators::CudnnConvOpKernel<double >);
264
- REGISTER_OP_GPU_KERNEL (conv_cudnn_grad ,
323
+ REGISTER_OP_GPU_KERNEL (conv3d_cudnn_grad ,
265
324
paddle::operators::CudnnConvGradOpKernel<float >,
266
325
paddle::operators::CudnnConvGradOpKernel<double >);
0 commit comments