@@ -25,7 +25,7 @@ class SliceOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

-  void InferShape(framework::InferShapeContext * ctx) const override {
+  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Input"),
                   "Input (Input) of slice op should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
@@ -58,7 +58,7 @@ class SliceOp : public framework::OperatorWithKernel {

 protected:
  framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext & ctx) const override {
+      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
        ctx.GetPlace());
@@ -87,13 +87,13 @@ Slice Operator.

Produces a slice of the input tensor along multiple axes. Similar to numpy:
https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html
-Slice uses `axes`, `starts` and `ends` attributes to specify the start and
+Slice uses `axes`, `starts` and `ends` attributes to specify the start and
end dimension for each axis in the list of axes, it uses this information
-to slice the input data tensor. If a negative value is passed for any of
-the start or end indices, it represents number of elements before the end
+to slice the input data tensor. If a negative value is passed for any of
+the start or end indices, it represents number of elements before the end
of that dimension. If the value passed to start or end is larger than
-the n (the number of elements in this dimension), it represents n.
-For slicing to the end of a dimension with unknown size, it is recommended
+the n (the number of elements in this dimension), it represents n.
+For slicing to the end of a dimension with unknown size, it is recommended
to pass in INT_MAX. If axes are omitted, they are set to [0, ..., ndim-1].
Following examples will explain how slice works:

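Aside on the rule the comment above describes (negative start/end values count back from the end of an axis; values larger than the axis size clamp to it, so INT_MAX slices to the end): the standalone sketch below illustrates how one (start, end) pair could be normalized. It is not part of the operator code, and the helper name NormalizeRange is purely illustrative.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <utility>

// Illustrative helper only: normalize one (start, end) pair for an axis of
// size n. Negative indices count back from the end of the axis, and
// out-of-range values are clamped into [0, n].
std::pair<int64_t, int64_t> NormalizeRange(int64_t n, int64_t start, int64_t end) {
  if (start < 0) start += n;
  if (end < 0) end += n;
  start = std::max<int64_t>(0, std::min(start, n));
  end = std::max<int64_t>(0, std::min(end, n));
  return {start, end};
}

int main() {
  auto r1 = NormalizeRange(5, 1, 1000);  // end beyond the axis clamps: [1, 5)
  auto r2 = NormalizeRange(5, 0, -1);    // negative end counts from the back: [0, 4)
  std::cout << r1.first << " " << r1.second << "\n";  // prints: 1 5
  std::cout << r2.first << " " << r2.second << "\n";  // prints: 0 4
  return 0;
}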
@@ -119,15 +119,54 @@ Following examples will explain how slice works:
  }
};

+class SliceOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("Input"), "Input should not be null");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+                   "Input(Out@GRAD) should not be null");
+    auto x_dims = ctx->GetInputDim("Input");
+    auto x_grad_name = framework::GradVarName("Input");
+    if (ctx->HasOutput(x_grad_name)) {
+      ctx->SetOutputDim(x_grad_name, x_dims);
+    }
+  }
+};
+
+class SliceOpGradMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    auto* bind = new framework::OpDesc();
+    bind->SetInput("Input", Input("Input"));
+    bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+    bind->SetOutput(framework::GradVarName("Input"), InputGrad("Input"));
+    bind->SetAttrMap(Attrs());
+    bind->SetType("slice_grad");
+    return std::unique_ptr<framework::OpDesc>(bind);
+  }
+};
+
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(slice, ops::SliceOp, ops::SliceOpMaker,
-                  paddle::framework::EmptyGradOpMaker);
+                  ops::SliceOpGradMaker);
+REGISTER_OPERATOR(slice_grad, ops::SliceOpGrad);

REGISTER_OP_CPU_KERNEL(
    slice, ops::SliceKernel<paddle::platform::CPUDeviceContext, int>,
    ops::SliceKernel<paddle::platform::CPUDeviceContext, int64_t>,
    ops::SliceKernel<paddle::platform::CPUDeviceContext, float>,
    ops::SliceKernel<paddle::platform::CPUDeviceContext, double>);
+
+REGISTER_OP_CPU_KERNEL(
+    slice_grad, ops::SliceGradKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::SliceGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
+    ops::SliceGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SliceGradKernel<paddle::platform::CPUDeviceContext, double>);
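For context on what the newly registered slice_grad op computes, conceptually: the gradient of a slice scatters Out@GRAD back into a zero tensor shaped like Input, which is why SliceOpGradMaker forwards Input (for its shape) and the original attributes to the grad op. The 1-D sketch below is an illustration only, not the real SliceGradKernel, and the name SliceGrad1D is hypothetical.

#include <cstddef>
#include <vector>

// Conceptual 1-D sketch: the gradient of a slice is the upstream gradient
// zero-padded back to the input's shape, since elements outside the sliced
// range receive no gradient.
std::vector<float> SliceGrad1D(const std::vector<float>& dout,
                               std::size_t input_size, std::size_t start) {
  std::vector<float> dinput(input_size, 0.0f);  // zeros outside the slice
  for (std::size_t i = 0; i < dout.size(); ++i) {
    dinput[start + i] = dout[i];  // copy d(Out) into [start, start + dout.size())
  }
  return dinput;
}

int main() {
  // Input of size 6 was sliced to [2, 5) in the forward pass; gradient flows
  // back only into positions 2..4, the rest stay zero.
  auto dinput = SliceGrad1D({1.0f, 2.0f, 3.0f}, 6, 2);
  return dinput.size() == 6 ? 0 : 1;
}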