Skip to content

Commit e0bca5f

Browse files
cjldqingqing01
authored and committed
Implement slice grad operator. #8130 (#12330)
* Implement slice grad operator. #8130 * test slice grad operator and bug fix * Fix pre commit style
1 parent 03dc7b7 commit e0bca5f

File tree

4 files changed

+132
-8
lines changed

4 files changed

+132
-8
lines changed

paddle/fluid/operators/slice_op.cc

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class SliceOp : public framework::OperatorWithKernel {
2525
public:
2626
using framework::OperatorWithKernel::OperatorWithKernel;
2727

28-
void InferShape(framework::InferShapeContext *ctx) const override {
28+
void InferShape(framework::InferShapeContext* ctx) const override {
2929
PADDLE_ENFORCE(ctx->HasInput("Input"),
3030
"Input (Input) of slice op should not be null.");
3131
PADDLE_ENFORCE(ctx->HasOutput("Out"),
@@ -58,7 +58,7 @@ class SliceOp : public framework::OperatorWithKernel {
5858

5959
protected:
6060
framework::OpKernelType GetExpectedKernelType(
61-
const framework::ExecutionContext &ctx) const override {
61+
const framework::ExecutionContext& ctx) const override {
6262
return framework::OpKernelType(
6363
framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
6464
ctx.GetPlace());
@@ -87,13 +87,13 @@ Slice Operator.
8787
8888
Produces a slice of the input tensor along multiple axes. Similar to numpy:
8989
https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html
90-
Slice uses `axes`, `starts` and `ends` attributes to specify the start and
90+
Slice uses `axes`, `starts` and `ends` attributes to specify the start and
9191
end dimension for each axis in the list of axes, it uses this information
92-
to slice the input data tensor. If a negative value is passed for any of
93-
the start or end indices, it represents number of elements before the end
92+
to slice the input data tensor. If a negative value is passed for any of
93+
the start or end indices, it represents number of elements before the end
9494
of that dimension. If the value passed to start or end is larger than
95-
the n (the number of elements in this dimension), it represents n.
96-
For slicing to the end of a dimension with unknown size, it is recommended
95+
the n (the number of elements in this dimension), it represents n.
96+
For slicing to the end of a dimension with unknown size, it is recommended
9797
to pass in INT_MAX. If axes are omitted, they are set to [0, ..., ndim-1].
9898
Following examples will explain how slice works:
9999
@@ -119,15 +119,54 @@ Following examples will explain how slice works:
119119
}
120120
};
121121

122+
class SliceOpGrad : public framework::OperatorWithKernel {
123+
public:
124+
using framework::OperatorWithKernel::OperatorWithKernel;
125+
126+
void InferShape(framework::InferShapeContext* ctx) const override {
127+
PADDLE_ENFORCE(ctx->HasInput("Input"), "Input should not be null");
128+
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
129+
"Input(Out@GRAD) should not be null");
130+
auto x_dims = ctx->GetInputDim("Input");
131+
auto x_grad_name = framework::GradVarName("Input");
132+
if (ctx->HasOutput(x_grad_name)) {
133+
ctx->SetOutputDim(x_grad_name, x_dims);
134+
}
135+
}
136+
};
137+
138+
class SliceOpGradMaker : public framework::SingleGradOpDescMaker {
139+
public:
140+
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
141+
142+
protected:
143+
std::unique_ptr<framework::OpDesc> Apply() const override {
144+
auto* bind = new framework::OpDesc();
145+
bind->SetInput("Input", Input("Input"));
146+
bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
147+
bind->SetOutput(framework::GradVarName("Input"), InputGrad("Input"));
148+
bind->SetAttrMap(Attrs());
149+
bind->SetType("slice_grad");
150+
return std::unique_ptr<framework::OpDesc>(bind);
151+
}
152+
};
153+
122154
} // namespace operators
123155
} // namespace paddle
124156

125157
namespace ops = paddle::operators;
126158
REGISTER_OPERATOR(slice, ops::SliceOp, ops::SliceOpMaker,
127-
paddle::framework::EmptyGradOpMaker);
159+
ops::SliceOpGradMaker);
160+
REGISTER_OPERATOR(slice_grad, ops::SliceOpGrad);
128161

129162
REGISTER_OP_CPU_KERNEL(
130163
slice, ops::SliceKernel<paddle::platform::CPUDeviceContext, int>,
131164
ops::SliceKernel<paddle::platform::CPUDeviceContext, int64_t>,
132165
ops::SliceKernel<paddle::platform::CPUDeviceContext, float>,
133166
ops::SliceKernel<paddle::platform::CPUDeviceContext, double>);
167+
168+
// CPU kernels for the slice gradient; dtype set mirrors the forward slice op.
REGISTER_OP_CPU_KERNEL(
    slice_grad, ops::SliceGradKernel<paddle::platform::CPUDeviceContext, int>,
    ops::SliceGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
    ops::SliceGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::SliceGradKernel<paddle::platform::CPUDeviceContext, double>);

paddle/fluid/operators/slice_op.cu

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,10 @@ REGISTER_OP_CUDA_KERNEL(
2020
ops::SliceKernel<paddle::platform::CUDADeviceContext, double>,
2121
ops::SliceKernel<paddle::platform::CUDADeviceContext, int>,
2222
ops::SliceKernel<paddle::platform::CUDADeviceContext, int64_t>);
23+
24+
// CUDA kernels for the slice gradient; dtype set mirrors the forward slice op.
REGISTER_OP_CUDA_KERNEL(
    slice_grad,
    ops::SliceGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::SliceGradKernel<paddle::platform::CUDADeviceContext, double>,
    ops::SliceGradKernel<paddle::platform::CUDADeviceContext, int>,
    ops::SliceGradKernel<paddle::platform::CUDADeviceContext, int64_t>);

paddle/fluid/operators/slice_op.h

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ limitations under the License. */
1414

1515
#pragma once
1616
#include <algorithm>
17+
#include <utility>
1718
#include <vector>
1819
#include "paddle/fluid/framework/op_registry.h"
1920

@@ -84,5 +85,79 @@ class SliceKernel : public framework::OpKernel<T> {
8485
out_t.device(place) = in_t.slice(offsets, extents);
8586
}
8687
};
88+
89+
template <typename DeviceContext, typename T>
90+
class SliceGradKernel : public framework::OpKernel<T> {
91+
public:
92+
void Compute(const framework::ExecutionContext& ctx) const override {
93+
size_t rank = ctx.Input<framework::Tensor>(framework::GradVarName("Out"))
94+
->dims()
95+
.size();
96+
switch (rank) {
97+
case 1:
98+
SliceCompute<1>(ctx);
99+
break;
100+
case 2:
101+
SliceCompute<2>(ctx);
102+
break;
103+
case 3:
104+
SliceCompute<3>(ctx);
105+
break;
106+
case 4:
107+
SliceCompute<4>(ctx);
108+
break;
109+
case 5:
110+
SliceCompute<5>(ctx);
111+
break;
112+
case 6:
113+
SliceCompute<6>(ctx);
114+
break;
115+
}
116+
}
117+
118+
private:
119+
template <size_t D>
120+
void SliceCompute(const framework::ExecutionContext& context) const {
121+
auto& place =
122+
*context.template device_context<DeviceContext>().eigen_device();
123+
auto* d_out =
124+
context.Input<framework::Tensor>(framework::GradVarName("Out"));
125+
auto* d_input =
126+
context.Output<framework::Tensor>(framework::GradVarName("Input"));
127+
d_input->mutable_data<T>(context.GetPlace());
128+
auto out_dims = d_out->dims();
129+
auto in_dims = d_input->dims();
130+
auto axes = context.Attr<std::vector<int>>("axes");
131+
auto starts = context.Attr<std::vector<int>>("starts");
132+
133+
auto offsets = Eigen::array<int, D>();
134+
auto extents = Eigen::array<int, D>();
135+
for (size_t i = 0; i < D; ++i) {
136+
offsets[i] = 0;
137+
extents[i] = out_dims[i];
138+
}
139+
int start;
140+
for (size_t i = 0; i < axes.size(); ++i) {
141+
start = starts[i];
142+
if (start < 0) {
143+
start = (start + in_dims[axes[i]]);
144+
}
145+
start = std::max(start, 0);
146+
offsets[axes[i]] = start;
147+
}
148+
Eigen::array<std::pair<int, int>, D> paddings;
149+
for (size_t i = 0; i < paddings.size(); ++i) {
150+
paddings[i].first = offsets[i];
151+
paddings[i].second = (in_dims[i] - out_dims[i]) - offsets[i];
152+
}
153+
auto d_in_t =
154+
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
155+
*d_input);
156+
auto d_out_t =
157+
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
158+
*d_out);
159+
d_in_t.device(place) = d_out_t.pad(paddings, 0);
160+
}
161+
};
87162
} // namespace operators
88163
} // namespace paddle

python/paddle/fluid/tests/unittests/test_slice_op.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ def config(self):
4141
def test_check_output(self):
4242
self.check_output()
4343

44+
def test_check_grad_normal(self):
    # Numeric gradient check of slice w.r.t. its Input tensor; tolerance
    # 0.006 presumably accommodates finite-difference noise — TODO confirm.
    self.check_grad(['Input'], 'Out', max_relative_error=0.006)
46+
4447

4548
class TestCase1(TestSliceOp):
4649
def config(self):

0 commit comments

Comments
 (0)