Skip to content

Commit b73708d

Browse files
author
chengduo
authored
add int and int64 dtype for gather_op (#14175)
test=develop
1 parent 62a0fe0 commit b73708d

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

paddle/fluid/operators/gather_op.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,9 @@ REGISTER_OPERATOR(gather, ops::GatherOp, ops::GatherOpMaker,
102102
paddle::framework::DefaultGradOpDescMaker<true>);
103103
REGISTER_OPERATOR(gather_grad, ops::GatherGradOp);
104104
REGISTER_OP_CPU_KERNEL(gather, ops::GatherOpKernel<float>,
105-
ops::GatherOpKernel<int>, ops::GatherOpKernel<double>);
105+
ops::GatherOpKernel<double>, ops::GatherOpKernel<int>,
106+
ops::GatherOpKernel<int64_t>);
106107
REGISTER_OP_CPU_KERNEL(gather_grad, ops::GatherGradientOpKernel<float>,
108+
ops::GatherGradientOpKernel<double>,
107109
ops::GatherGradientOpKernel<int>,
108-
ops::GatherGradientOpKernel<double>);
110+
ops::GatherGradientOpKernel<int64_t>);

paddle/fluid/operators/gather_op.cu

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,5 +61,11 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
6161
} // namespace paddle
6262

6363
namespace ops = paddle::operators;
64-
REGISTER_OP_CUDA_KERNEL(gather, ops::GatherOpCUDAKernel<float>);
65-
REGISTER_OP_CUDA_KERNEL(gather_grad, ops::GatherGradOpCUDAKernel<float>);
64+
REGISTER_OP_CUDA_KERNEL(gather, ops::GatherOpCUDAKernel<float>,
65+
ops::GatherOpCUDAKernel<double>,
66+
ops::GatherOpCUDAKernel<int64_t>,
67+
ops::GatherOpCUDAKernel<int>);
68+
REGISTER_OP_CUDA_KERNEL(gather_grad, ops::GatherGradOpCUDAKernel<float>,
69+
ops::GatherGradOpCUDAKernel<double>,
70+
ops::GatherGradOpCUDAKernel<int64_t>,
71+
ops::GatherGradOpCUDAKernel<int>);

0 commit comments

Comments
 (0)