Skip to content

Commit ca743de

Browse files
committed
enable more type for splitOp and ConcatOp
1 parent 431491a commit ca743de

File tree

4 files changed

+24
-6
lines changed

4 files changed

+24
-6
lines changed

paddle/fluid/operators/concat_op.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,13 @@ REGISTER_OPERATOR(concat, ops::ConcatOp, ops::ConcatOpMaker,
107107
false> /* set false to disable empty grad */);
108108
REGISTER_OPERATOR(concat_grad, ops::ConcatOpGrad);
109109
REGISTER_OP_CPU_KERNEL(
110-
concat, ops::ConcatKernel<paddle::platform::CPUDeviceContext, float>);
110+
concat, ops::ConcatKernel<paddle::platform::CPUDeviceContext, double>,
111+
ops::ConcatKernel<paddle::platform::CPUDeviceContext, float>,
112+
ops::ConcatKernel<paddle::platform::CPUDeviceContext, int64_t>,
113+
ops::ConcatKernel<paddle::platform::CPUDeviceContext, int>);
111114
REGISTER_OP_CPU_KERNEL(
112115
concat_grad,
113-
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, float>);
116+
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, double>,
117+
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, float>,
118+
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
119+
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, int>);

paddle/fluid/operators/concat_op.cu.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,13 @@ limitations under the License. */
1515
#include "paddle/fluid/operators/concat_op.h"
1616
namespace ops = paddle::operators;
1717
REGISTER_OP_CUDA_KERNEL(
18-
concat, ops::ConcatKernel<paddle::platform::CUDADeviceContext, float>);
18+
concat, ops::ConcatKernel<paddle::platform::CUDADeviceContext, double>,
19+
ops::ConcatKernel<paddle::platform::CUDADeviceContext, float>,
20+
ops::ConcatKernel<paddle::platform::CUDADeviceContext, int64_t>,
21+
ops::ConcatKernel<paddle::platform::CUDADeviceContext, int>);
1922
REGISTER_OP_CUDA_KERNEL(
2023
concat_grad,
21-
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, float>);
24+
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, double>,
25+
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, float>,
26+
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
27+
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, int>);

paddle/fluid/operators/split_op.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,4 +115,7 @@ USE_CPU_ONLY_OP(concat);
115115

116116
REGISTER_OPERATOR(split, ops::SplitOp, ops::SplitOpMaker, ops::SplitGradMaker);
117117
REGISTER_OP_CPU_KERNEL(split,
118-
ops::SplitOpKernel<paddle::platform::CPUPlace, float>);
118+
ops::SplitOpKernel<paddle::platform::CPUPlace, double>,
119+
ops::SplitOpKernel<paddle::platform::CPUPlace, float>,
120+
ops::SplitOpKernel<paddle::platform::CPUPlace, int64_t>,
121+
ops::SplitOpKernel<paddle::platform::CPUPlace, int>);

paddle/fluid/operators/split_op.cu.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,7 @@ limitations under the License. */
1515
#include "paddle/fluid/operators/split_op.h"
1616
namespace ops = paddle::operators;
1717
REGISTER_OP_CUDA_KERNEL(
18-
split, ops::SplitOpKernel<paddle::platform::CUDADeviceContext, float>);
18+
split, ops::SplitOpKernel<paddle::platform::CUDADeviceContext, double>,
19+
ops::SplitOpKernel<paddle::platform::CUDADeviceContext, float>,
20+
ops::SplitOpKernel<paddle::platform::CUDADeviceContext, int64_t>,
21+
ops::SplitOpKernel<paddle::platform::CUDADeviceContext, int>);

0 commit comments

Comments
 (0)