Skip to content

Commit 7c3ec22

Browse files
authored
"fix gpu related op registered" (#5647)
1 parent 9f28925 commit 7c3ec22

File tree

4 files changed

+30
-8
lines changed

4 files changed

+30
-8
lines changed

paddle/operators/elementwise_add_op.cu

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,13 @@ namespace ops = paddle::operators;
1919

2020
REGISTER_OP_GPU_KERNEL(
2121
elementwise_add,
22-
ops::ElementwiseAddKernel<paddle::platform::GPUPlace, float>);
22+
ops::ElementwiseAddKernel<paddle::platform::GPUPlace, float>,
23+
ops::ElementwiseAddKernel<paddle::platform::GPUPlace, double>,
24+
ops::ElementwiseAddKernel<paddle::platform::GPUPlace, int>,
25+
ops::ElementwiseAddKernel<paddle::platform::GPUPlace, int64_t>);
2326
REGISTER_OP_GPU_KERNEL(
2427
elementwise_add_grad,
25-
ops::ElementwiseAddGradKernel<paddle::platform::GPUPlace, float>);
28+
ops::ElementwiseAddGradKernel<paddle::platform::GPUPlace, float>,
29+
ops::ElementwiseAddGradKernel<paddle::platform::GPUPlace, double>,
30+
ops::ElementwiseAddGradKernel<paddle::platform::GPUPlace, int>,
31+
ops::ElementwiseAddGradKernel<paddle::platform::GPUPlace, int64_t>);

paddle/operators/elementwise_div_op.cu

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,13 @@ namespace ops = paddle::operators;
1919

2020
REGISTER_OP_GPU_KERNEL(
2121
elementwise_div,
22-
ops::ElementwiseDivKernel<paddle::platform::GPUPlace, float>);
22+
ops::ElementwiseDivKernel<paddle::platform::GPUPlace, float>,
23+
ops::ElementwiseDivKernel<paddle::platform::GPUPlace, double>,
24+
ops::ElementwiseDivKernel<paddle::platform::GPUPlace, int>,
25+
ops::ElementwiseDivKernel<paddle::platform::GPUPlace, int64_t>);
2326
REGISTER_OP_GPU_KERNEL(
2427
elementwise_div_grad,
25-
ops::ElementwiseDivGradKernel<paddle::platform::GPUPlace, float>);
28+
ops::ElementwiseDivGradKernel<paddle::platform::GPUPlace, float>,
29+
ops::ElementwiseDivGradKernel<paddle::platform::GPUPlace, double>,
30+
ops::ElementwiseDivGradKernel<paddle::platform::GPUPlace, int>,
31+
ops::ElementwiseDivGradKernel<paddle::platform::GPUPlace, int64_t>);

paddle/operators/elementwise_mul_op.cu

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,12 @@ namespace ops = paddle::operators;
2020
REGISTER_OP_GPU_KERNEL(
2121
elementwise_mul,
2222
ops::ElementwiseMulKernel<paddle::platform::GPUPlace, float>,
23-
ops::ElementwiseMulKernel<paddle::platform::GPUPlace, double>);
23+
ops::ElementwiseMulKernel<paddle::platform::GPUPlace, double>,
24+
ops::ElementwiseMulKernel<paddle::platform::GPUPlace, int>,
25+
ops::ElementwiseMulKernel<paddle::platform::GPUPlace, int64_t>);
2426
REGISTER_OP_GPU_KERNEL(
2527
elementwise_mul_grad,
2628
ops::ElementwiseMulGradKernel<paddle::platform::GPUPlace, float>,
27-
ops::ElementwiseMulGradKernel<paddle::platform::GPUPlace, double>);
29+
ops::ElementwiseMulGradKernel<paddle::platform::GPUPlace, double>,
30+
ops::ElementwiseMulGradKernel<paddle::platform::GPUPlace, int>,
31+
ops::ElementwiseMulGradKernel<paddle::platform::GPUPlace, int64_t>);

paddle/operators/elementwise_sub_op.cu

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,13 @@ namespace ops = paddle::operators;
1919

2020
REGISTER_OP_GPU_KERNEL(
2121
elementwise_sub,
22-
ops::ElementwiseSubKernel<paddle::platform::GPUPlace, float>);
22+
ops::ElementwiseSubKernel<paddle::platform::GPUPlace, float>,
23+
ops::ElementwiseSubKernel<paddle::platform::GPUPlace, double>,
24+
ops::ElementwiseSubKernel<paddle::platform::GPUPlace, int>,
25+
ops::ElementwiseSubKernel<paddle::platform::GPUPlace, int64_t>);
2326
REGISTER_OP_GPU_KERNEL(
2427
elementwise_sub_grad,
25-
ops::ElementwiseSubGradKernel<paddle::platform::GPUPlace, float>);
28+
ops::ElementwiseSubGradKernel<paddle::platform::GPUPlace, float>,
29+
ops::ElementwiseSubGradKernel<paddle::platform::GPUPlace, double>,
30+
ops::ElementwiseSubGradKernel<paddle::platform::GPUPlace, int>,
31+
ops::ElementwiseSubGradKernel<paddle::platform::GPUPlace, int64_t>);

0 commit comments

Comments
 (0)