Skip to content

Commit bd29437

Browse files
author
Yibing Liu
authored
Fix gather & stack op (#14355)
* Add int type support for stack_op
* Improve gather op to support index with shape N x 1 (test=develop)
* Fix stack_op kernel's registry (test=develop)
1 parent 9d4425d commit bd29437

File tree

7 files changed

+25
-10
lines changed

7 files changed

+25
-10
lines changed

paddle/fluid/operators/gather.cu.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,9 @@ void GPUGather(const platform::DeviceContext& ctx, const Tensor& src,
5050
const Tensor& index, Tensor* output) {
5151
// PADDLE_ENFORCE(platform::is_gpu_place(place));
5252
// check index of shape 1-D
53-
PADDLE_ENFORCE(index.dims().size() == 1);
53+
PADDLE_ENFORCE(index.dims().size() == 1 ||
54+
(index.dims().size() == 2 && index.dims()[1] == 1));
55+
5456
int index_size = index.dims()[0];
5557

5658
auto src_dims = src.dims();

paddle/fluid/operators/gather.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ void CPUGather(const platform::DeviceContext& ctx, const Tensor& src,
3838
const Tensor& index, Tensor* output) {
3939
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()));
4040
// check index of shape 1-D
41-
PADDLE_ENFORCE(index.dims().size() == 1);
41+
PADDLE_ENFORCE(index.dims().size() == 1 ||
42+
(index.dims().size() == 2 && index.dims()[1] == 1));
4243
int64_t index_size = index.dims()[0];
4344

4445
auto src_dims = src.dims();

paddle/fluid/operators/gather_op.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ class GatherOp : public framework::OperatorWithKernel {
3131
"Output(Out) of GatherOp should not be null.");
3232

3333
auto index_dims = ctx->GetInputDim("Index");
34-
PADDLE_ENFORCE(index_dims.size() == 1);
34+
PADDLE_ENFORCE(index_dims.size() == 1 ||
35+
(index_dims.size() == 2 && index_dims[1] == 1));
3536
int batch_size = ctx->GetInputDim("Index")[0];
3637
framework::DDim output_dims(ctx->GetInputDim("X"));
3738
output_dims[0] = batch_size;
@@ -53,6 +54,7 @@ class GatherGradOp : public framework::OperatorWithKernel {
5354

5455
void InferShape(framework::InferShapeContext* ctx) const override {
5556
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
57+
ctx->ShareLoD("X", /*-->*/ framework::GradVarName("X"));
5658
}
5759

5860
protected:
@@ -75,7 +77,7 @@ Gather Operator.
7577
7678
$Out = X[Index]$
7779
78-
Out is obtained by gathering entries of the outer-most dimension
80+
Out is obtained by gathering entries of the outer-most dimension
7981
of X indexed by Index and concatenate them together.
8082
8183
Example:

paddle/fluid/operators/scatter.cu.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ void GPUScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
5151
const Tensor& index, Tensor* output) {
5252
// PADDLE_ENFORCE(platform::is_gpu_place(place));
5353
// check index of shape 1-D
54-
PADDLE_ENFORCE(index.dims().size() == 1);
54+
PADDLE_ENFORCE(index.dims().size() == 1 ||
55+
(index.dims().size() == 2 && index.dims()[1] == 1));
5556
int index_size = index.dims()[0];
5657

5758
auto src_dims = src.dims();

paddle/fluid/operators/scatter.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
3737
const Tensor& index, Tensor* output) {
3838
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()));
3939
// check index of shape 1-D
40-
PADDLE_ENFORCE(index.dims().size() == 1);
40+
PADDLE_ENFORCE(index.dims().size() == 1 ||
41+
(index.dims().size() == 2 && index.dims()[1] == 1));
4142
int index_size = index.dims()[0];
4243

4344
auto src_dims = src.dims();

paddle/fluid/operators/stack_op.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,12 @@ REGISTER_OPERATOR(stack, ops::StackOp, ops::StackOpMaker,
2121
REGISTER_OPERATOR(stack_grad, ops::StackOpGrad);
2222

2323
REGISTER_OP_CPU_KERNEL(stack, ops::StackKernel<plat::CPUDeviceContext, float>,
24-
ops::StackKernel<plat::CPUDeviceContext, double>);
24+
ops::StackKernel<plat::CPUDeviceContext, double>,
25+
ops::StackKernel<plat::CPUDeviceContext, int>,
26+
ops::StackKernel<plat::CPUDeviceContext, int64_t>);
2527

2628
REGISTER_OP_CPU_KERNEL(stack_grad,
2729
ops::StackGradKernel<plat::CPUDeviceContext, float>,
28-
ops::StackGradKernel<plat::CPUDeviceContext, double>);
30+
ops::StackGradKernel<plat::CPUDeviceContext, double>,
31+
ops::StackGradKernel<plat::CPUDeviceContext, int>,
32+
ops::StackGradKernel<plat::CPUDeviceContext, int64_t>);

paddle/fluid/operators/stack_op.cu

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,12 @@ namespace plat = paddle::platform;
1818
namespace ops = paddle::operators;
1919

2020
REGISTER_OP_CUDA_KERNEL(stack, ops::StackKernel<plat::CUDADeviceContext, float>,
21-
ops::StackKernel<plat::CUDADeviceContext, double>);
21+
ops::StackKernel<plat::CUDADeviceContext, double>,
22+
ops::StackKernel<plat::CUDADeviceContext, int>,
23+
ops::StackKernel<plat::CUDADeviceContext, int64_t>);
2224

2325
REGISTER_OP_CUDA_KERNEL(stack_grad,
2426
ops::StackGradKernel<plat::CUDADeviceContext, float>,
25-
ops::StackGradKernel<plat::CUDADeviceContext, double>);
27+
ops::StackGradKernel<plat::CUDADeviceContext, double>,
28+
ops::StackGradKernel<plat::CUDADeviceContext, int>,
29+
ops::StackGradKernel<plat::CUDADeviceContext, int64_t>);

0 commit comments

Comments
 (0)