Skip to content

Commit adc0908

Browse files
Add slice op. (#11052)
* Add slice op. * Remove using from header file and fix doc. * Fix doc * Small fix.
1 parent 1d19849 commit adc0908

File tree

5 files changed

+303
-0
lines changed

5 files changed

+303
-0
lines changed

paddle/fluid/operators/slice_op.cc

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/slice_op.h"
16+
#include <algorithm>
17+
#include <vector>
18+
19+
namespace paddle {
20+
namespace operators {
21+
22+
using Tensor = framework::Tensor;
23+
24+
class SliceOp : public framework::OperatorWithKernel {
25+
public:
26+
using framework::OperatorWithKernel::OperatorWithKernel;
27+
28+
void InferShape(framework::InferShapeContext *ctx) const override {
29+
PADDLE_ENFORCE(ctx->HasInput("Input"),
30+
"Input (Input) of slice op should not be null.");
31+
PADDLE_ENFORCE(ctx->HasOutput("Out"),
32+
"Output (Out) of slice op should not be null.");
33+
34+
auto in_dims = ctx->GetInputDim("Input");
35+
PADDLE_ENFORCE(in_dims.size() < 7,
36+
"The rank of input should be less than 7.");
37+
framework::DDim out_dims(in_dims);
38+
auto axes = ctx->Attrs().Get<std::vector<int>>("axes");
39+
auto starts = ctx->Attrs().Get<std::vector<int>>("starts");
40+
auto ends = ctx->Attrs().Get<std::vector<int>>("ends");
41+
42+
PADDLE_ENFORCE_EQ(starts.size(), ends.size());
43+
PADDLE_ENFORCE_EQ(starts.size(), axes.size());
44+
int dim_value, start, end;
45+
for (size_t i = 0; i < axes.size(); ++i) {
46+
dim_value = out_dims[axes[i]];
47+
start = starts[i] < 0 ? (starts[i] + dim_value) : starts[i];
48+
end = ends[i] < 0 ? (ends[i] + dim_value) : ends[i];
49+
start = std::max(start, 0);
50+
end = std::max(end, 0);
51+
start = std::min(start, dim_value);
52+
end = std::min(end, dim_value);
53+
start = std::min(start, end);
54+
out_dims[axes[i]] = end - start;
55+
}
56+
ctx->SetOutputDim("Out", out_dims);
57+
}
58+
59+
protected:
60+
framework::OpKernelType GetExpectedKernelType(
61+
const framework::ExecutionContext &ctx) const override {
62+
return framework::OpKernelType(
63+
framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
64+
ctx.GetPlace());
65+
}
66+
};
67+
68+
// Declares the slice op's inputs/outputs/attributes and user-facing docs.
class SliceOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Input", "Tensor of data to extract slices from.");
    AddOutput("Out", "Sliced data tensor.");

    AddAttr<std::vector<int>>(
        "axes",
        "(list<int>) Axes that `starts` and `ends` apply to. It's optional. "
        "If not present, will be treated as [0, 1, ..., len(`starts`) - 1].");
    AddAttr<std::vector<int>>(
        "starts",
        "(list<int>) Starting indices of corresponding axis in `axes`");
    // Fixed copy-paste error: this attribute holds *ending* indices.
    AddAttr<std::vector<int>>(
        "ends",
        "(list<int>) Ending indices of corresponding axis in `axes`.");

    AddComment(R"DOC(
Slice Operator.

Produces a slice of the input tensor along multiple axes. Similar to numpy:
https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html
Slice uses `axes`, `starts` and `ends` attributes to specify the start and
end dimension for each axis in the list of axes, it uses this information
to slice the input data tensor. If a negative value is passed for any of
the start or end indices, it represents number of elements before the end
of that dimension. If the value passed to start or end is larger than
the n (the number of elements in this dimension), it represents n.
For slicing to the end of a dimension with unknown size, it is recommended
to pass in INT_MAX. If axes are omitted, they are set to [0, ..., ndim-1].

Example 1:
Given:
    data = [ [1, 2, 3, 4], [5, 6, 7, 8], ]
    axes = [0, 1]
    starts = [1, 0]
    ends = [2, 3]
Then:
    result = [ [5, 6, 7], ]

Example 2:
Given:
    data = [ [1, 2, 3, 4], [5, 6, 7, 8], ]
    starts = [0, 1]
    ends = [-1, 1000]
Then:
    result = [ [2, 3, 4], ]
)DOC");
  }
};
118+
119+
} // namespace operators
120+
} // namespace paddle
121+
122+
namespace ops = paddle::operators;
// No gradient op is registered yet (EmptyGradOpMaker), so slice is
// forward-only at this point.
REGISTER_OPERATOR(slice, ops::SliceOp, ops::SliceOpMaker,
                  paddle::framework::EmptyGradOpMaker);

// CPU kernels for the element types the op supports.
REGISTER_OP_CPU_KERNEL(
    slice, ops::SliceKernel<paddle::platform::CPUDeviceContext, int>,
    ops::SliceKernel<paddle::platform::CPUDeviceContext, int64_t>,
    ops::SliceKernel<paddle::platform::CPUDeviceContext, float>,
    ops::SliceKernel<paddle::platform::CPUDeviceContext, double>);

paddle/fluid/operators/slice_op.cu

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/slice_op.h"

namespace ops = paddle::operators;
// CUDA kernels: same SliceKernel template as CPU, instantiated for the
// CUDA device context and the supported element types.
REGISTER_OP_CUDA_KERNEL(
    slice, ops::SliceKernel<paddle::platform::CUDADeviceContext, float>,
    ops::SliceKernel<paddle::platform::CUDADeviceContext, double>,
    ops::SliceKernel<paddle::platform::CUDADeviceContext, int>,
    ops::SliceKernel<paddle::platform::CUDADeviceContext, int64_t>);

paddle/fluid/operators/slice_op.h

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
#include <algorithm>
17+
#include <vector>
18+
#include "paddle/fluid/framework/op_registry.h"
19+
20+
namespace paddle {
21+
namespace operators {
22+
23+
// Computes the slice on CPU or GPU by dispatching on the input's rank to a
// fixed-rank Eigen implementation (Eigen tensor slicing needs the rank as a
// compile-time constant, hence the switch over ranks 1..6).
template <typename DeviceContext, typename T>
class SliceKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    int rank = ctx.Input<framework::Tensor>("Input")->dims().size();
    // NOTE(review): ranks outside 1..6 fall through silently; InferShape
    // enforces rank < 7, but rank 0 would be a no-op here.
    switch (rank) {
      case 1:
        SliceCompute<1>(ctx);
        break;
      case 2:
        SliceCompute<2>(ctx);
        break;
      case 3:
        SliceCompute<3>(ctx);
        break;
      case 4:
        SliceCompute<4>(ctx);
        break;
      case 5:
        SliceCompute<5>(ctx);
        break;
      case 6:
        SliceCompute<6>(ctx);
        break;
    }
  }

 private:
  // Rank-D implementation: builds per-dimension (offset, extent) pairs and
  // evaluates Eigen's slice() on the device associated with DeviceContext.
  template <size_t D>
  void SliceCompute(const framework::ExecutionContext& context) const {
    auto& place =
        *context.template device_context<DeviceContext>().eigen_device();
    auto in = context.Input<framework::Tensor>("Input");
    auto out = context.Output<framework::Tensor>("Out");
    // Allocate the output before wrapping it in an EigenTensor below.
    out->mutable_data<T>(context.GetPlace());
    auto out_dims = out->dims();
    auto in_dims = in->dims();
    auto axes = context.Attr<std::vector<int>>("axes");
    auto starts = context.Attr<std::vector<int>>("starts");

    // Default: take the full extent of every dimension starting at 0; the
    // extents were already resolved by InferShape into out_dims.
    auto offsets = Eigen::array<int, D>();
    auto extents = Eigen::array<int, D>();
    for (size_t i = 0; i < D; ++i) {
      offsets[i] = 0;
      extents[i] = out_dims[i];
    }
    int start;
    // Only the sliced axes get a non-zero offset; negative starts count
    // from the end of the dimension, then are clamped at 0.
    for (size_t i = 0; i < axes.size(); ++i) {
      start = starts[i];
      if (start < 0) {
        start = (start + in_dims[axes[i]]);
      }
      start = std::max(start, 0);
      offsets[axes[i]] = start;
    }
    auto in_t =
        framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
            *in);
    auto out_t =
        framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
            *out);
    out_t.device(place) = in_t.slice(offsets, extents);
  }
};
87+
} // namespace operators
88+
} // namespace paddle

python/paddle/fluid/layers/ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
'cumsum',
7272
'scatter',
7373
'sum',
74+
'slice',
7475
'polygon_box_transform',
7576
'shape',
7677
'maxout',
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
import numpy as np
17+
from op_test import OpTest
18+
19+
20+
class TestSliceOp(OpTest):
    # Base case: slice a rank-4 tensor along the first three axes with
    # plain in-range indices. Subclasses override config() only.
    def setUp(self):
        self.op_type = "slice"
        self.config()
        self.inputs = {'Input': self.input}
        self.outputs = {'Out': self.out}
        self.attrs = {'axes': self.axes, 'starts': self.starts, 'ends': self.ends}

    def config(self):
        # Expected result is computed with numpy's native slicing.
        self.input = np.random.random([3, 4, 5, 6]).astype("float32")
        self.axes = [0, 1, 2]
        self.starts = [1, 0, 2]
        self.ends = [3, 3, 4]
        self.out = self.input[1:3, 0:3, 2:4, :]

    def test_check_output(self):
        self.check_output()
43+
class TestCase1(TestSliceOp):
    # Exercises negative starts/ends and an end past the dimension size.
    def config(self):
        self.input = np.random.random([3, 4, 5, 6]).astype("float32")
        self.axes = [0, 1, 2]
        self.starts = [-3, 0, 2]
        self.ends = [3, 100, -1]
        self.out = self.input[-3:3, 0:100, 2:-1, :]
50+
51+
52+
class TestCase2(TestSliceOp):
    # Same indices as TestCase1 but applied to non-contiguous axes (0, 1, 3),
    # leaving axis 2 untouched.
    def config(self):
        self.input = np.random.random([3, 4, 5, 6]).astype("float32")
        self.axes = [0, 1, 3]
        self.starts = [-3, 0, 2]
        self.ends = [3, 100, -1]
        self.out = self.input[-3:3, 0:100, :, 2:-1]
59+
60+
61+
# Run all slice-op test cases when invoked directly.
if __name__ == '__main__':
    unittest.main()

0 commit comments

Comments
 (0)