Skip to content

Commit fafd3e0

Browse files
committed
Merge branch 'develop' into softsign
2 parents dffa8fa + 134eaf2 commit fafd3e0

21 files changed

+745 -190 lines changed

paddle/operators/conv_cudnn_op.cc

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,8 @@ REGISTER_OP(conv_cudnn, ops::ConvOp, ops::CudnnConvOpMaker, conv_cudnn_grad,
4040
ops::ConvOpGrad);
4141

4242
REGISTER_OP_CPU_KERNEL(conv_cudnn,
43-
ops::GemmConvKernel<paddle::platform::CPUPlace, float>);
43+
ops::GemmConvKernel<paddle::platform::CPUPlace, float>,
44+
ops::GemmConvKernel<paddle::platform::CPUPlace, double>);
4445
REGISTER_OP_CPU_KERNEL(
45-
conv_cudnn_grad,
46-
ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>);
46+
conv_cudnn_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>,
47+
ops::GemmConvGradKernel<paddle::platform::CPUPlace, double>);

paddle/operators/conv_cudnn_op.cu.cc

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -259,6 +259,8 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
259259
} // namespace operators
260260
} // namespace paddle
261261

262-
REGISTER_OP_GPU_KERNEL(conv_cudnn, paddle::operators::CudnnConvOpKernel<float>);
262+
REGISTER_OP_GPU_KERNEL(conv_cudnn, paddle::operators::CudnnConvOpKernel<float>,
263+
paddle::operators::CudnnConvOpKernel<double>);
263264
REGISTER_OP_GPU_KERNEL(conv_cudnn_grad,
264-
paddle::operators::CudnnConvGradOpKernel<float>);
265+
paddle::operators::CudnnConvGradOpKernel<float>,
266+
paddle::operators::CudnnConvGradOpKernel<double>);

paddle/operators/conv_transpose_cudnn_op.cc

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -61,18 +61,22 @@ REGISTER_OP(conv2d_transpose_cudnn, ops::ConvTransposeOp,
6161

6262
REGISTER_OP_CPU_KERNEL(
6363
conv2d_transpose_cudnn,
64-
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>);
64+
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>,
65+
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, double>);
6566
REGISTER_OP_CPU_KERNEL(
6667
conv2d_transpose_cudnn_grad,
67-
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>);
68+
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>,
69+
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, double>);
6870

6971
REGISTER_OP(conv3d_transpose_cudnn, ops::ConvTransposeOp,
7072
ops::CudnnConv3DTransposeOpMaker, conv3d_transpose_cudnn_grad,
7173
ops::ConvTransposeOpGrad);
7274

7375
REGISTER_OP_CPU_KERNEL(
7476
conv3d_transpose_cudnn,
75-
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>);
77+
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>,
78+
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, double>);
7679
REGISTER_OP_CPU_KERNEL(
7780
conv3d_transpose_cudnn_grad,
78-
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>);
81+
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>,
82+
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, double>);

paddle/operators/conv_transpose_cudnn_op.cu.cc

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -235,11 +235,15 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
235235
namespace ops = paddle::operators;
236236

237237
REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn,
238-
ops::CudnnConvTransposeOpKernel<float>);
238+
ops::CudnnConvTransposeOpKernel<float>,
239+
ops::CudnnConvTransposeOpKernel<double>);
239240
REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn_grad,
240-
ops::CudnnConvTransposeGradOpKernel<float>);
241+
ops::CudnnConvTransposeGradOpKernel<float>,
242+
ops::CudnnConvTransposeGradOpKernel<double>);
241243

242244
REGISTER_OP_GPU_KERNEL(conv3d_transpose_cudnn,
243-
ops::CudnnConvTransposeOpKernel<float>);
245+
ops::CudnnConvTransposeOpKernel<float>,
246+
ops::CudnnConvTransposeOpKernel<double>);
244247
REGISTER_OP_GPU_KERNEL(conv3d_transpose_cudnn_grad,
245-
ops::CudnnConvTransposeGradOpKernel<float>);
248+
ops::CudnnConvTransposeGradOpKernel<float>,
249+
ops::CudnnConvTransposeGradOpKernel<double>);

paddle/operators/pool_cudnn_op.cc

Lines changed: 14 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,18 @@ REGISTER_OP(pool2d_cudnn, ops::PoolOp, ops::Pool2dOpMaker, pool2d_cudnn_grad,
2020
ops::PoolOpGrad);
2121

2222
REGISTER_OP_CPU_KERNEL(pool2d_cudnn,
23-
ops::PoolKernel<paddle::platform::CPUPlace, float>);
23+
ops::PoolKernel<paddle::platform::CPUPlace, float>,
24+
ops::PoolKernel<paddle::platform::CPUPlace, double>);
2425
REGISTER_OP_CPU_KERNEL(pool2d_cudnn_grad,
25-
ops::PoolGradKernel<paddle::platform::CPUPlace, float>)
26+
ops::PoolGradKernel<paddle::platform::CPUPlace, float>,
27+
ops::PoolGradKernel<paddle::platform::CPUPlace, double>)
28+
29+
REGISTER_OP(pool3d_cudnn, ops::PoolOp, ops::Pool3dOpMaker, pool3d_cudnn_grad,
30+
ops::PoolOpGrad);
31+
32+
REGISTER_OP_CPU_KERNEL(pool3d_cudnn,
33+
ops::PoolKernel<paddle::platform::CPUPlace, float>,
34+
ops::PoolKernel<paddle::platform::CPUPlace, double>);
35+
REGISTER_OP_CPU_KERNEL(pool3d_cudnn_grad,
36+
ops::PoolGradKernel<paddle::platform::CPUPlace, float>,
37+
ops::PoolGradKernel<paddle::platform::CPUPlace, double>)

paddle/operators/pool_cudnn_op.cu.cc

Lines changed: 23 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -52,7 +52,13 @@ class PoolCudnnOpKernel : public framework::OpKernel<T> {
5252
ScopedTensorDescriptor input_desc;
5353
ScopedTensorDescriptor output_desc;
5454
ScopedPoolingDescriptor pool_desc;
55-
DataLayout layout = DataLayout::kNCHW;
55+
DataLayout layout;
56+
57+
if (strides.size() == 2U) {
58+
layout = DataLayout::kNCHW;
59+
} else {
60+
layout = DataLayout::kNCDHW;
61+
}
5662

5763
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
5864
layout, framework::vectorize2int(input->dims()));
@@ -112,7 +118,13 @@ class PoolCudnnGradOpKernel : public framework::OpKernel<T> {
112118
ScopedTensorDescriptor input_desc;
113119
ScopedTensorDescriptor output_desc;
114120
ScopedPoolingDescriptor pool_desc;
115-
DataLayout layout = DataLayout::kNCHW;
121+
DataLayout layout;
122+
123+
if (strides.size() == 2U) {
124+
layout = DataLayout::kNCHW;
125+
} else {
126+
layout = DataLayout::kNCDHW;
127+
}
116128

117129
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
118130
layout, framework::vectorize2int(input->dims()));
@@ -150,5 +162,12 @@ class PoolCudnnGradOpKernel : public framework::OpKernel<T> {
150162

151163
namespace ops = paddle::operators;
152164

153-
REGISTER_OP_GPU_KERNEL(pool2d_cudnn, ops::PoolCudnnOpKernel<float>);
154-
REGISTER_OP_GPU_KERNEL(pool2d_cudnn_grad, ops::PoolCudnnGradOpKernel<float>);
165+
REGISTER_OP_GPU_KERNEL(pool2d_cudnn, ops::PoolCudnnOpKernel<float>,
166+
ops::PoolCudnnOpKernel<double>);
167+
REGISTER_OP_GPU_KERNEL(pool2d_cudnn_grad, ops::PoolCudnnGradOpKernel<float>,
168+
ops::PoolCudnnGradOpKernel<double>);
169+
170+
REGISTER_OP_GPU_KERNEL(pool3d_cudnn, ops::PoolCudnnOpKernel<float>,
171+
ops::PoolCudnnOpKernel<double>);
172+
REGISTER_OP_GPU_KERNEL(pool3d_cudnn_grad, ops::PoolCudnnGradOpKernel<float>,
173+
ops::PoolCudnnGradOpKernel<double>);

paddle/operators/pool_op.cc

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -217,14 +217,18 @@ REGISTER_OP(pool2d, ops::PoolOp, ops::Pool2dOpMaker, pool2d_grad,
217217
ops::PoolOpGrad);
218218

219219
REGISTER_OP_CPU_KERNEL(pool2d,
220-
ops::PoolKernel<paddle::platform::CPUPlace, float>);
220+
ops::PoolKernel<paddle::platform::CPUPlace, float>,
221+
ops::PoolKernel<paddle::platform::CPUPlace, double>);
221222
REGISTER_OP_CPU_KERNEL(pool2d_grad,
222-
ops::PoolGradKernel<paddle::platform::CPUPlace, float>)
223+
ops::PoolGradKernel<paddle::platform::CPUPlace, float>,
224+
ops::PoolGradKernel<paddle::platform::CPUPlace, double>)
223225

224226
REGISTER_OP(pool3d, ops::PoolOp, ops::Pool3dOpMaker, pool3d_grad,
225227
ops::PoolOpGrad);
226228

227229
REGISTER_OP_CPU_KERNEL(pool3d,
228-
ops::PoolKernel<paddle::platform::CPUPlace, float>);
230+
ops::PoolKernel<paddle::platform::CPUPlace, float>,
231+
ops::PoolKernel<paddle::platform::CPUPlace, double>);
229232
REGISTER_OP_CPU_KERNEL(pool3d_grad,
230-
ops::PoolGradKernel<paddle::platform::CPUPlace, float>);
233+
ops::PoolGradKernel<paddle::platform::CPUPlace, float>,
234+
ops::PoolGradKernel<paddle::platform::CPUPlace, double>);

paddle/operators/pool_op.cu.cc

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -17,11 +17,15 @@ limitations under the License. */
1717
namespace ops = paddle::operators;
1818

1919
REGISTER_OP_GPU_KERNEL(pool2d,
20-
ops::PoolKernel<paddle::platform::GPUPlace, float>);
20+
ops::PoolKernel<paddle::platform::GPUPlace, float>,
21+
ops::PoolKernel<paddle::platform::GPUPlace, double>);
2122
REGISTER_OP_GPU_KERNEL(pool2d_grad,
22-
ops::PoolGradKernel<paddle::platform::GPUPlace, float>);
23+
ops::PoolGradKernel<paddle::platform::GPUPlace, float>,
24+
ops::PoolGradKernel<paddle::platform::GPUPlace, double>);
2325

2426
REGISTER_OP_GPU_KERNEL(pool3d,
25-
ops::PoolKernel<paddle::platform::GPUPlace, float>);
27+
ops::PoolKernel<paddle::platform::GPUPlace, float>,
28+
ops::PoolKernel<paddle::platform::GPUPlace, double>);
2629
REGISTER_OP_GPU_KERNEL(pool3d_grad,
27-
ops::PoolGradKernel<paddle::platform::GPUPlace, float>);
30+
ops::PoolGradKernel<paddle::platform::GPUPlace, float>,
31+
ops::PoolGradKernel<paddle::platform::GPUPlace, double>);

paddle/operators/pool_with_index_op.cc

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -250,18 +250,22 @@ REGISTER_OP(max_pool2d_with_index, ops::MaxPoolWithIndexOp,
250250

251251
REGISTER_OP_CPU_KERNEL(
252252
max_pool2d_with_index,
253-
ops::MaxPoolWithIndexKernel<paddle::platform::CPUPlace, float>);
253+
ops::MaxPoolWithIndexKernel<paddle::platform::CPUPlace, float>,
254+
ops::MaxPoolWithIndexKernel<paddle::platform::CPUPlace, double>);
254255
REGISTER_OP_CPU_KERNEL(
255256
max_pool2d_with_index_grad,
256-
ops::MaxPoolWithIndexGradKernel<paddle::platform::CPUPlace, float>)
257+
ops::MaxPoolWithIndexGradKernel<paddle::platform::CPUPlace, float>,
258+
ops::MaxPoolWithIndexGradKernel<paddle::platform::CPUPlace, double>)
257259

258260
REGISTER_OP(max_pool3d_with_index, ops::MaxPoolWithIndexOp,
259261
ops::MaxPool3dWithIndexOpMaker, max_pool3d_with_index_grad,
260262
ops::MaxPoolWithIndexOpGrad);
261263

262264
REGISTER_OP_CPU_KERNEL(
263265
max_pool3d_with_index,
264-
ops::MaxPoolWithIndexKernel<paddle::platform::CPUPlace, float>);
266+
ops::MaxPoolWithIndexKernel<paddle::platform::CPUPlace, float>,
267+
ops::MaxPoolWithIndexKernel<paddle::platform::CPUPlace, double>);
265268
REGISTER_OP_CPU_KERNEL(
266269
max_pool3d_with_index_grad,
267-
ops::MaxPoolWithIndexGradKernel<paddle::platform::CPUPlace, float>)
270+
ops::MaxPoolWithIndexGradKernel<paddle::platform::CPUPlace, float>,
271+
ops::MaxPoolWithIndexGradKernel<paddle::platform::CPUPlace, double>)

paddle/operators/pool_with_index_op.cu.cc

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -18,14 +18,18 @@ namespace ops = paddle::operators;
1818

1919
REGISTER_OP_GPU_KERNEL(
2020
max_pool2d_with_index,
21-
ops::MaxPoolWithIndexKernel<paddle::platform::GPUPlace, float>);
21+
ops::MaxPoolWithIndexKernel<paddle::platform::GPUPlace, float>,
22+
ops::MaxPoolWithIndexKernel<paddle::platform::GPUPlace, double>);
2223
REGISTER_OP_GPU_KERNEL(
2324
max_pool2d_with_index_grad,
24-
ops::MaxPoolWithIndexGradKernel<paddle::platform::GPUPlace, float>)
25+
ops::MaxPoolWithIndexGradKernel<paddle::platform::GPUPlace, float>,
26+
ops::MaxPoolWithIndexGradKernel<paddle::platform::GPUPlace, double>)
2527

2628
REGISTER_OP_GPU_KERNEL(
2729
max_pool3d_with_index,
28-
ops::MaxPoolWithIndexKernel<paddle::platform::GPUPlace, float>);
30+
ops::MaxPoolWithIndexKernel<paddle::platform::GPUPlace, float>,
31+
ops::MaxPoolWithIndexKernel<paddle::platform::GPUPlace, double>);
2932
REGISTER_OP_GPU_KERNEL(
3033
max_pool3d_with_index_grad,
31-
ops::MaxPoolWithIndexGradKernel<paddle::platform::GPUPlace, float>)
34+
ops::MaxPoolWithIndexGradKernel<paddle::platform::GPUPlace, float>,
35+
ops::MaxPoolWithIndexGradKernel<paddle::platform::GPUPlace, double>)

0 commit comments

Comments (0)