Skip to content

Commit 0f83674

Browse files
author
chengduo
authored
Merge pull request #5603 from chengduoZH/Add_conv3d_transpose_cudnn_op
add conv3d_trans_cudnn_op
2 parents 2113cbf + c359e39 commit 0f83674

File tree

11 files changed

+122
-54
lines changed

11 files changed

+122
-54
lines changed

paddle/operators/CMakeLists.txt

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -61,16 +61,30 @@ function(op_library TARGET)
6161
set(pybind_flag 1)
6262
endif()
6363

64+
if ("${TARGET}" STREQUAL "compare_op")
65+
set(pybind_flag 1)
66+
file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(equal);\n")
67+
endif()
68+
69+
# conv_op contains several operators
70+
if ("${TARGET}" STREQUAL "conv_op")
71+
set(pybind_flag 1)
72+
# It's enough to add just one operator to pybind
73+
file(APPEND ${pybind_file} "USE_OP(conv2d);\n")
74+
endif()
75+
6476
# pool_op contains several operators
6577
if ("${TARGET}" STREQUAL "pool_op")
6678
set(pybind_flag 1)
6779
# It's enough to add just one operator to pybind
6880
file(APPEND ${pybind_file} "USE_OP(pool2d);\n")
6981
endif()
7082

71-
if ("${TARGET}" STREQUAL "compare_op")
83+
# pool_cudnn_op contains several operators
84+
if ("${TARGET}" STREQUAL "pool_cudnn_op")
7285
set(pybind_flag 1)
73-
file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(equal);\n")
86+
# It's enough to add just one operator to pybind
87+
file(APPEND ${pybind_file} "USE_OP(pool2d_cudnn);\n")
7488
endif()
7589

7690
# pool_with_index_op contains several operators
@@ -80,25 +94,18 @@ function(op_library TARGET)
8094
file(APPEND ${pybind_file} "USE_OP(max_pool2d_with_index);\n")
8195
endif()
8296

83-
# conv_op contains several operators
84-
if ("${TARGET}" STREQUAL "conv_op")
85-
set(pybind_flag 1)
86-
# It's enough to just adding one operator to pybind
87-
file(APPEND ${pybind_file} "USE_OP(conv2d);\n")
88-
endif()
89-
9097
# conv_transpose_op contains several operators
9198
if ("${TARGET}" STREQUAL "conv_transpose_op")
9299
set(pybind_flag 1)
93100
# It's enough to add just one operator to pybind
94101
file(APPEND ${pybind_file} "USE_OP(conv2d_transpose);\n")
95102
endif()
96-
97-
# pool_cudnn_op contains several operators
98-
if ("${TARGET}" STREQUAL "pool_cudnn_op")
103+
104+
# conv_transpose_cudnn_op contains two operators
105+
if ("${TARGET}" STREQUAL "conv_transpose_cudnn_op")
99106
set(pybind_flag 1)
100107
# It's enough to add just one operator to pybind
101-
file(APPEND ${pybind_file} "USE_OP(pool2d_cudnn);\n")
108+
file(APPEND ${pybind_file} "USE_OP(conv2d_transpose_cudnn);\n")
102109
endif()
103110

104111
# save_restore_op contains several operators

paddle/operators/conv_cudnn_op.cu.cc

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -226,9 +226,8 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
226226
T alpha = 1.0f, beta = 0.0f;
227227
if (input_grad) {
228228
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
229-
auto t = framework::EigenVector<T>::Flatten(*input_grad);
230-
t.device(ctx.GetEigenDevice<platform::GPUPlace>()) =
231-
t.constant(static_cast<T>(0));
229+
// Because beta is zero, it is unnecessary to reset input_grad.
230+
232231
for (int i = 0; i < groups; i++) {
233232
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
234233
handle, &alpha, cudnn_filter_desc,
@@ -241,9 +240,8 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
241240
// ------------------- cudnn conv backward filter ---------------------
242241
if (filter_grad) {
243242
T* filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
244-
auto t = framework::EigenVector<T>::Flatten(*filter_grad);
245-
t.device(ctx.GetEigenDevice<platform::GPUPlace>()) =
246-
t.constant(static_cast<T>(0));
243+
// Because beta is zero, it is unnecessary to reset filter_grad.
244+
247245
for (int i = 0; i < groups; i++) {
248246
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
249247
handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,

paddle/operators/conv_op.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -225,11 +225,15 @@ REGISTER_OP(conv3d, ops::ConvOp, ops::Conv3DOpMaker, conv3d_grad,
225225
ops::ConvOpGrad);
226226

227227
REGISTER_OP_CPU_KERNEL(conv2d,
228-
ops::GemmConvKernel<paddle::platform::CPUPlace, float>);
228+
ops::GemmConvKernel<paddle::platform::CPUPlace, float>,
229+
ops::GemmConvKernel<paddle::platform::CPUPlace, double>);
229230
REGISTER_OP_CPU_KERNEL(
230-
conv2d_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>);
231+
conv2d_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>,
232+
ops::GemmConvGradKernel<paddle::platform::CPUPlace, double>);
231233

232234
REGISTER_OP_CPU_KERNEL(conv3d,
233-
ops::GemmConvKernel<paddle::platform::CPUPlace, float>);
235+
ops::GemmConvKernel<paddle::platform::CPUPlace, float>,
236+
ops::GemmConvKernel<paddle::platform::CPUPlace, double>);
234237
REGISTER_OP_CPU_KERNEL(
235-
conv3d_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>);
238+
conv3d_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>,
239+
ops::GemmConvGradKernel<paddle::platform::CPUPlace, double>);

paddle/operators/conv_op.cu.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,15 @@
1717
namespace ops = paddle::operators;
1818

1919
REGISTER_OP_GPU_KERNEL(conv2d,
20-
ops::GemmConvKernel<paddle::platform::GPUPlace, float>);
20+
ops::GemmConvKernel<paddle::platform::GPUPlace, float>,
21+
ops::GemmConvKernel<paddle::platform::GPUPlace, double>);
2122
REGISTER_OP_GPU_KERNEL(
22-
conv2d_grad, ops::GemmConvGradKernel<paddle::platform::GPUPlace, float>);
23+
conv2d_grad, ops::GemmConvGradKernel<paddle::platform::GPUPlace, float>,
24+
ops::GemmConvGradKernel<paddle::platform::GPUPlace, double>);
2325

2426
REGISTER_OP_GPU_KERNEL(conv3d,
25-
ops::GemmConvKernel<paddle::platform::GPUPlace, float>);
27+
ops::GemmConvKernel<paddle::platform::GPUPlace, float>,
28+
ops::GemmConvKernel<paddle::platform::GPUPlace, double>);
2629
REGISTER_OP_GPU_KERNEL(
27-
conv3d_grad, ops::GemmConvGradKernel<paddle::platform::GPUPlace, float>);
30+
conv3d_grad, ops::GemmConvGradKernel<paddle::platform::GPUPlace, float>,
31+
ops::GemmConvGradKernel<paddle::platform::GPUPlace, double>);

paddle/operators/conv2d_transpose_cudnn_op.cc renamed to paddle/operators/conv_transpose_cudnn_op.cc

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,24 @@ class CudnnConv2DTransposeOpMaker : public Conv2DTransposeOpMaker {
2323
framework::OpAttrChecker* op_checker)
2424
: Conv2DTransposeOpMaker(proto, op_checker) {
2525
AddAttr<std::vector<int>>("dilations", "dilations of convolution operator.")
26-
.SetDefault(std::vector<int>{1, 1});
26+
.SetDefault({1, 1});
27+
AddAttr<int>("workspace_size_MB",
28+
"workspace size for cudnn, in MB, "
29+
"workspace is a section of GPU memory which will be "
30+
"allocated/freed each time the operator runs, larger "
31+
"workspace size can increase performance but also requires "
32+
"better hardward. This size should be carefully setted.")
33+
.SetDefault(4096);
34+
}
35+
};
36+
37+
class CudnnConv3DTransposeOpMaker : public Conv3DTransposeOpMaker {
38+
public:
39+
CudnnConv3DTransposeOpMaker(framework::OpProto* proto,
40+
framework::OpAttrChecker* op_checker)
41+
: Conv3DTransposeOpMaker(proto, op_checker) {
42+
AddAttr<std::vector<int>>("dilations", "dilations of convolution operator.")
43+
.SetDefault({1, 1, 1});
2744
AddAttr<int>("workspace_size_MB",
2845
"workspace size for cudnn, in MB, "
2946
"workspace is a section of GPU memory which will be "
@@ -48,3 +65,14 @@ REGISTER_OP_CPU_KERNEL(
4865
REGISTER_OP_CPU_KERNEL(
4966
conv2d_transpose_cudnn_grad,
5067
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>);
68+
69+
REGISTER_OP(conv3d_transpose_cudnn, ops::ConvTransposeOp,
70+
ops::CudnnConv3DTransposeOpMaker, conv3d_transpose_cudnn_grad,
71+
ops::ConvTransposeOpGrad);
72+
73+
REGISTER_OP_CPU_KERNEL(
74+
conv3d_transpose_cudnn,
75+
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>);
76+
REGISTER_OP_CPU_KERNEL(
77+
conv3d_transpose_cudnn_grad,
78+
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>);

paddle/operators/conv2d_transpose_cudnn_op.cu.cc renamed to paddle/operators/conv_transpose_cudnn_op.cu.cc

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -54,15 +54,21 @@ class CudnnConvTransposeOpKernel : public framework::OpKernel<T> {
5454
ScopedTensorDescriptor output_desc;
5555
ScopedFilterDescriptor filter_desc;
5656
ScopedConvolutionDescriptor conv_desc;
57-
DataLayout layout = DataLayout::kNCHW;
57+
DataLayout layout;
58+
59+
if (strides.size() == 2U) {
60+
layout = DataLayout::kNCHW;
61+
} else {
62+
layout = DataLayout::kNCDHW;
63+
}
5864

59-
// N, M, H, W
65+
// (N, M, H, W) or (N, M, D, H, W)
6066
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
6167
layout, framework::vectorize2int(input->dims()));
62-
// N, C, O_h, O_w
68+
// (N, C, O_h, O_w) or (N, C, O_d, O_h, O_w)
6369
cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
6470
layout, framework::vectorize2int(output->dims()));
65-
// M, C, K_h, K_w
71+
// (M, C, K_h, K_w) or (M, C, K_d, K_h, K_w)
6672
cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
6773
layout, framework::vectorize2int(filter->dims()));
6874
cudnnConvolutionDescriptor_t cudnn_conv_desc =
@@ -136,13 +142,13 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
136142
ScopedConvolutionDescriptor conv_desc;
137143
DataLayout layout = DataLayout::kNCHW;
138144

139-
// Input: (N, M, H, W)
145+
// Input: (N, M, H, W) or (N, M, D, H, W)
140146
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
141147
layout, framework::vectorize2int(input->dims()));
142-
// Output: (N, C, O_H, O_W)
148+
// Output: (N, C, O_h, O_w) or (N, C, O_d, O_h, O_w)
143149
cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
144150
layout, framework::vectorize2int(output_grad->dims()));
145-
// Filter (M, C, K_H, K_W)
151+
// Filter (M, C, K_h, K_w) or (M, C, K_d, K_h, K_w)
146152
cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
147153
layout, framework::vectorize2int(filter->dims()));
148154

@@ -200,8 +206,7 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
200206
T alpha = 1.0f, beta = 0.0f;
201207
if (input_grad) {
202208
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
203-
math::set_constant(ctx.device_context(), input_grad, 0);
204-
209+
// Because beta is zero, it is unnecessary to reset input_grad.
205210
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionForward(
206211
handle, &alpha, cudnn_output_desc, output_grad_data,
207212
cudnn_filter_desc, filter_data, cudnn_conv_desc, data_algo,
@@ -212,8 +217,7 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
212217
// ------------------- cudnn conv backward filter ---------------------
213218
if (filter_grad) {
214219
T* filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
215-
math::set_constant(ctx.device_context(), filter_grad, 0);
216-
220+
// Because beta is zero, it is unnecessary to reset filter_grad.
217221
// Gradient with respect to the filter
218222
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
219223
handle, &alpha, cudnn_output_desc, output_grad_data, cudnn_input_desc,
@@ -234,3 +238,8 @@ REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn,
234238
ops::CudnnConvTransposeOpKernel<float>);
235239
REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn_grad,
236240
ops::CudnnConvTransposeGradOpKernel<float>);
241+
242+
REGISTER_OP_GPU_KERNEL(conv3d_transpose_cudnn,
243+
ops::CudnnConvTransposeOpKernel<float>);
244+
REGISTER_OP_GPU_KERNEL(conv3d_transpose_cudnn_grad,
245+
ops::CudnnConvTransposeGradOpKernel<float>);

paddle/operators/conv_transpose_op.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -185,17 +185,21 @@ REGISTER_OP(conv2d_transpose, ops::ConvTransposeOp, ops::Conv2DTransposeOpMaker,
185185

186186
REGISTER_OP_CPU_KERNEL(
187187
conv2d_transpose,
188-
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>);
188+
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>,
189+
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, double>);
189190
REGISTER_OP_CPU_KERNEL(
190191
conv2d_transpose_grad,
191-
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>);
192+
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>,
193+
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, double>);
192194

193195
REGISTER_OP(conv3d_transpose, ops::ConvTransposeOp, ops::Conv3DTransposeOpMaker,
194196
conv3d_transpose_grad, ops::ConvTransposeOpGrad);
195197

196198
REGISTER_OP_CPU_KERNEL(
197199
conv3d_transpose,
198-
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>);
200+
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>,
201+
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, double>);
199202
REGISTER_OP_CPU_KERNEL(
200203
conv3d_transpose_grad,
201-
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>);
204+
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>,
205+
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, double>);

paddle/operators/conv_transpose_op.cu.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,18 @@ namespace ops = paddle::operators;
1818

1919
REGISTER_OP_GPU_KERNEL(
2020
conv2d_transpose,
21-
ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, float>);
21+
ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, float>,
22+
ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, double>);
2223
REGISTER_OP_GPU_KERNEL(
2324
conv2d_transpose_grad,
24-
ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, float>);
25+
ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, float>,
26+
ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, double>);
2527

2628
REGISTER_OP_GPU_KERNEL(
2729
conv3d_transpose,
28-
ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, float>);
30+
ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, float>,
31+
ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, double>);
2932
REGISTER_OP_GPU_KERNEL(
3033
conv3d_transpose_grad,
31-
ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, float>);
34+
ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, float>,
35+
ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, double>);

paddle/operators/pool_cudnn_op.cu.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,7 @@ class PoolCudnnGradOpKernel : public framework::OpKernel<T> {
135135

136136
if (input_grad) {
137137
T *input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
138-
math::SetConstant<paddle::platform::GPUPlace, T> set_zero;
139-
set_zero(ctx.device_context(), input_grad, static_cast<T>(0));
138+
// Because beta is zero, it is unnecessary to reset input_grad.
140139

141140
PADDLE_ENFORCE(platform::dynload::cudnnPoolingBackward(
142141
handle, cudnn_pool_desc, &alpha, cudnn_output_desc, output_data,

paddle/platform/cudnn_helper.h

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,10 @@ inline const char* cudnnGetErrorString(cudnnStatus_t status) {
6363
} \
6464
} while (false)
6565

66-
enum class DataLayout {
66+
enum class DataLayout { // Not use
6767
kNHWC,
6868
kNCHW,
69+
kNCDHW,
6970
kNCHW_VECT_C,
7071
};
7172

@@ -107,12 +108,15 @@ class CudnnDataType<double> {
107108
}
108109
};
109110

110-
inline cudnnTensorFormat_t GetCudnnTensorFormat(const DataLayout& order) {
111+
inline cudnnTensorFormat_t GetCudnnTensorFormat(
112+
const DataLayout& order) { // Not use
111113
switch (order) {
112114
case DataLayout::kNHWC:
113115
return CUDNN_TENSOR_NHWC;
114116
case DataLayout::kNCHW:
115117
return CUDNN_TENSOR_NCHW;
118+
case DataLayout::kNCDHW:
119+
return CUDNN_TENSOR_NCHW; // TODO(chengduoZH) : add CUDNN_TENSOR_NCDHW
116120
default:
117121
PADDLE_THROW("Unknown cudnn equivalent for order");
118122
}
@@ -139,7 +143,7 @@ class ScopedTensorDescriptor {
139143
strides[i] = dims[i + 1] * strides[i + 1];
140144
}
141145
// Update tensor descriptor dims setting if groups > 1
142-
// FIXME(typhoonzero): Assume using NCHW order
146+
// FIXME(typhoonzero): Assume using NCHW or NCDHW order
143147
std::vector<int> dims_with_group(dims.begin(), dims.end()); // copy
144148
if (groups > 1) {
145149
dims_with_group[1] = dims_with_group[1] / groups;
@@ -176,9 +180,10 @@ class ScopedFilterDescriptor {
176180
const cudnnDataType_t type,
177181
const std::vector<int>& kernel,
178182
const int groups = 1) {
179-
// filter layout: MCHW, where M is the number of
183+
// filter layout: MCHW(MCDHW), where M is the number of
180184
// output image channels, C is the number of input image channels,
181-
// H and W is height and width of filter.
185+
// D is the depth of the filter, H is the height of the filter, and W is the
186+
// width of the filter.
182187
std::vector<int> kernel_with_group(kernel.begin(), kernel.end());
183188
if (groups > 1) {
184189
// M /= groups

0 commit comments

Comments
 (0)