Skip to content

Commit c359e39

Browse files
committed
add double type kernel
1 parent 0bc2f41 commit c359e39

File tree

4 files changed

+32
-16
lines changed

4 files changed

+32
-16
lines changed

paddle/operators/conv_op.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -225,11 +225,15 @@ REGISTER_OP(conv3d, ops::ConvOp, ops::Conv3DOpMaker, conv3d_grad,
225225
ops::ConvOpGrad);
226226

227227
REGISTER_OP_CPU_KERNEL(conv2d,
228-
ops::GemmConvKernel<paddle::platform::CPUPlace, float>);
228+
ops::GemmConvKernel<paddle::platform::CPUPlace, float>,
229+
ops::GemmConvKernel<paddle::platform::CPUPlace, double>);
229230
REGISTER_OP_CPU_KERNEL(
230-
conv2d_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>);
231+
conv2d_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>,
232+
ops::GemmConvGradKernel<paddle::platform::CPUPlace, double>);
231233

232234
REGISTER_OP_CPU_KERNEL(conv3d,
233-
ops::GemmConvKernel<paddle::platform::CPUPlace, float>);
235+
ops::GemmConvKernel<paddle::platform::CPUPlace, float>,
236+
ops::GemmConvKernel<paddle::platform::CPUPlace, double>);
234237
REGISTER_OP_CPU_KERNEL(
235-
conv3d_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>);
238+
conv3d_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>,
239+
ops::GemmConvGradKernel<paddle::platform::CPUPlace, double>);

paddle/operators/conv_op.cu.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,15 @@
1717
namespace ops = paddle::operators;
1818

1919
REGISTER_OP_GPU_KERNEL(conv2d,
20-
ops::GemmConvKernel<paddle::platform::GPUPlace, float>);
20+
ops::GemmConvKernel<paddle::platform::GPUPlace, float>,
21+
ops::GemmConvKernel<paddle::platform::GPUPlace, double>);
2122
REGISTER_OP_GPU_KERNEL(
22-
conv2d_grad, ops::GemmConvGradKernel<paddle::platform::GPUPlace, float>);
23+
conv2d_grad, ops::GemmConvGradKernel<paddle::platform::GPUPlace, float>,
24+
ops::GemmConvGradKernel<paddle::platform::GPUPlace, double>);
2325

2426
REGISTER_OP_GPU_KERNEL(conv3d,
25-
ops::GemmConvKernel<paddle::platform::GPUPlace, float>);
27+
ops::GemmConvKernel<paddle::platform::GPUPlace, float>,
28+
ops::GemmConvKernel<paddle::platform::GPUPlace, double>);
2629
REGISTER_OP_GPU_KERNEL(
27-
conv3d_grad, ops::GemmConvGradKernel<paddle::platform::GPUPlace, float>);
30+
conv3d_grad, ops::GemmConvGradKernel<paddle::platform::GPUPlace, float>,
31+
ops::GemmConvGradKernel<paddle::platform::GPUPlace, double>);

paddle/operators/conv_transpose_op.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -185,17 +185,21 @@ REGISTER_OP(conv2d_transpose, ops::ConvTransposeOp, ops::Conv2DTransposeOpMaker,
185185

186186
REGISTER_OP_CPU_KERNEL(
187187
conv2d_transpose,
188-
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>);
188+
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>,
189+
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, double>);
189190
REGISTER_OP_CPU_KERNEL(
190191
conv2d_transpose_grad,
191-
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>);
192+
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>,
193+
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, double>);
192194

193195
REGISTER_OP(conv3d_transpose, ops::ConvTransposeOp, ops::Conv3DTransposeOpMaker,
194196
conv3d_transpose_grad, ops::ConvTransposeOpGrad);
195197

196198
REGISTER_OP_CPU_KERNEL(
197199
conv3d_transpose,
198-
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>);
200+
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>,
201+
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, double>);
199202
REGISTER_OP_CPU_KERNEL(
200203
conv3d_transpose_grad,
201-
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>);
204+
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>,
205+
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, double>);

paddle/operators/conv_transpose_op.cu.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,18 @@ namespace ops = paddle::operators;
1818

1919
REGISTER_OP_GPU_KERNEL(
2020
conv2d_transpose,
21-
ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, float>);
21+
ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, float>,
22+
ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, double>);
2223
REGISTER_OP_GPU_KERNEL(
2324
conv2d_transpose_grad,
24-
ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, float>);
25+
ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, float>,
26+
ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, double>);
2527

2628
REGISTER_OP_GPU_KERNEL(
2729
conv3d_transpose,
28-
ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, float>);
30+
ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, float>,
31+
ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, double>);
2932
REGISTER_OP_GPU_KERNEL(
3033
conv3d_transpose_grad,
31-
ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, float>);
34+
ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, float>,
35+
ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, double>);

0 commit comments

Comments
 (0)