Skip to content

Commit 3b6090e

Browse files
authored
Merge pull request #12887 from chenwhql/sequence_enumerate_op
Feat: add sequence enumerate op
2 parents 4529f70 + 7ddbbcb commit 3b6090e

File tree

7 files changed

+397
-0
lines changed

7 files changed

+397
-0
lines changed

paddle/fluid/API.spec

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ paddle.fluid.layers.sequence_mask ArgSpec(args=['x', 'maxlen', 'dtype', 'name'],
172172
paddle.fluid.layers.stack ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=(0,))
173173
paddle.fluid.layers.pad2d ArgSpec(args=['input', 'paddings', 'mode', 'pad_value', 'data_format', 'name'], varargs=None, keywords=None, defaults=([0, 0, 0, 0], 'constant', 0.0, 'NCHW', None))
174174
paddle.fluid.layers.unstack ArgSpec(args=['x', 'axis', 'num'], varargs=None, keywords=None, defaults=(0, None))
175+
paddle.fluid.layers.sequence_enumerate ArgSpec(args=['input', 'win_size', 'pad_value', 'name'], varargs=None, keywords=None, defaults=(0, None))
175176
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
176177
paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True))
177178
paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/operators/sequence_enumerate_op.h"
16+
17+
namespace paddle {
18+
namespace operators {
19+
20+
class SequenceEnumerateOp : public framework::OperatorWithKernel {
21+
public:
22+
using framework::OperatorWithKernel::OperatorWithKernel;
23+
24+
void InferShape(framework::InferShapeContext* ctx) const override {
25+
PADDLE_ENFORCE(
26+
ctx->HasInput("X"),
27+
"Input(X) of SequecceEnumerate operator should not be null.");
28+
PADDLE_ENFORCE(
29+
ctx->HasOutput("Out"),
30+
"Output(X) of SequenceEnumerate operator should not be null.");
31+
32+
const auto x_dims = ctx->GetInputDim("X");
33+
PADDLE_ENFORCE_EQ(
34+
x_dims.size(), 2UL,
35+
"Input(X) of SequenceEnumerate operator's rank should be 2.");
36+
PADDLE_ENFORCE_EQ(
37+
x_dims[1], 1UL,
38+
"Input(X) of SequenceEnumerate operator's 2nd dimension should be 1.");
39+
40+
const auto win_size = ctx->Attrs().Get<int>("win_size");
41+
ctx->SetOutputDim("Out", {x_dims[0], win_size});
42+
ctx->ShareLoD("X", "Out");
43+
}
44+
};
45+
46+
class SequenceEnumerateOpMaker : public framework::OpProtoAndCheckerMaker {
47+
public:
48+
void Make() override {
49+
AddInput("X",
50+
"(2-D LoDTensor with the 2nd dimension equal to 1) "
51+
"Input LoDTensor of SequenceEnumerate operator.");
52+
AddOutput("Out",
53+
"(2-D LoDTensor with the 2nd dimension equal to win_size) "
54+
"Output LoDTensor of SequenceEnumerate operator.");
55+
AddAttr<int>("win_size", "(int) The enumerate sequence window size.")
56+
.AddCustomChecker([](const int& win_size) {
57+
PADDLE_ENFORCE(win_size >= 2,
58+
"The window size should be not less than 2.");
59+
});
60+
AddAttr<int>("pad_value", "(int) The enumerate sequence padding value.")
61+
.SetDefault(0);
62+
AddComment(R"DOC(
63+
Sequence Enumerate Operator.
64+
65+
Generate a new sequence for the input index sequence, which enumerates all the
66+
sub-sequences with length `win_size` of the input.
67+
The enumerated sequence has the same 1st dimension with variable `input`, and
68+
the 2nd dimension is `win_size`, padded by `pad_value` if necessary in generation.
69+
70+
Examples:
71+
Case 1:
72+
Input:
73+
X.lod = [[0, 3, 5]]
74+
X.data = [[1], [2], [3], [4], [5]]
75+
X.dims = [5, 1]
76+
Attrs:
77+
win_size = 2
78+
pad_value = 0
79+
Output:
80+
Out.lod = [[0, 3, 5]]
81+
Out.data = [[1, 2], [2, 3], [3, 0], [4, 5], [5, 0]]
82+
Out.dims = [5, 2]
83+
84+
)DOC");
85+
}
86+
};
87+
88+
} // namespace operators
89+
} // namespace paddle
90+
91+
namespace ops = paddle::operators;
92+
REGISTER_OP_WITHOUT_GRADIENT(sequence_enumerate, ops::SequenceEnumerateOp,
93+
ops::SequenceEnumerateOpMaker);
94+
REGISTER_OP_CPU_KERNEL(
95+
sequence_enumerate,
96+
ops::SequenceEnumerateKernel<paddle::platform::CPUDeviceContext, int32_t>,
97+
ops::SequenceEnumerateKernel<paddle::platform::CPUDeviceContext, int64_t>);
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <thrust/device_vector.h>
16+
#include <thrust/host_vector.h>
17+
#include "paddle/fluid/operators/sequence_enumerate_op.h"
18+
#include "paddle/fluid/platform/cuda_primitives.h"
19+
20+
namespace paddle {
21+
namespace operators {
22+
using platform::PADDLE_CUDA_NUM_THREADS;
23+
using LoDTensor = framework::LoDTensor;
24+
25+
template <typename T>
26+
__global__ void CalcOutPut(const T* in_data, const size_t* in_lod,
27+
const size_t lod_len, const int64_t win_size,
28+
const int64_t pad_value, T* out_data) {
29+
int index = blockIdx.x * blockDim.x + threadIdx.x;
30+
if (index < in_lod[lod_len - 1]) {
31+
int end_idx = 0;
32+
// Get LoD interval of index
33+
for (int i = 1; i < lod_len; ++i) {
34+
if (index < in_lod[i]) {
35+
end_idx = in_lod[i];
36+
break;
37+
}
38+
}
39+
for (size_t i = 0; i < win_size; ++i) {
40+
int word_pos = index + i;
41+
out_data[index * win_size + i] =
42+
word_pos < end_idx ? in_data[word_pos] : pad_value;
43+
}
44+
}
45+
}
46+
47+
template <typename T>
48+
class SequenceEnumerateOpCUDAKernel : public framework::OpKernel<T> {
49+
public:
50+
void Compute(const framework::ExecutionContext& context) const override {
51+
auto* in = context.Input<LoDTensor>("X");
52+
auto* out = context.Output<LoDTensor>("Out");
53+
int win_size = context.Attr<int>("win_size");
54+
int pad_value = context.Attr<int>("pad_value");
55+
56+
auto in_dims = in->dims();
57+
auto in_lod = in->lod();
58+
59+
PADDLE_ENFORCE_EQ(
60+
static_cast<uint64_t>(in_dims[0]), in_lod[0].back(),
61+
"The actual input data's size mismatched with LoD information.");
62+
63+
/* Generate enumerate sequence set */
64+
auto stream = context.cuda_device_context().stream();
65+
auto lod0 = in_lod[0];
66+
auto in_len = in->numel();
67+
auto in_data = in->data<T>();
68+
auto out_data = out->mutable_data<T>(context.GetPlace());
69+
// Copy LoD to GPU
70+
const size_t* dev_in_lod_ptr = lod0.CUDAData(context.GetPlace());
71+
// Calc output tensor
72+
CalcOutPut<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
73+
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
74+
in_data, dev_in_lod_ptr, lod0.size(), win_size, pad_value, out_data);
75+
}
76+
};
77+
78+
} // namespace operators
79+
} // namespace paddle
80+
81+
REGISTER_OP_CUDA_KERNEL(
82+
sequence_enumerate,
83+
paddle::operators::SequenceEnumerateOpCUDAKernel<int32_t>,
84+
paddle::operators::SequenceEnumerateOpCUDAKernel<int64_t>);
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include "paddle/fluid/framework/op_registry.h"
18+
19+
namespace paddle {
20+
namespace operators {
21+
using LoDTensor = framework::LoDTensor;
22+
23+
template <typename DeviceContext, typename T>
24+
class SequenceEnumerateKernel : public framework::OpKernel<T> {
25+
public:
26+
void Compute(const framework::ExecutionContext& context) const override {
27+
auto* in = context.Input<LoDTensor>("X");
28+
auto* out = context.Output<LoDTensor>("Out");
29+
int win_size = context.Attr<int>("win_size");
30+
int pad_value = context.Attr<int>("pad_value");
31+
32+
auto in_dims = in->dims();
33+
auto in_lod = in->lod();
34+
35+
PADDLE_ENFORCE_EQ(
36+
static_cast<uint64_t>(in_dims[0]), in_lod[0].back(),
37+
"The actual input data's size mismatched with LoD information.");
38+
39+
// Generate enumerate sequence set
40+
auto lod0 = in_lod[0];
41+
auto in_data = in->data<T>();
42+
auto out_data = out->mutable_data<T>(context.GetPlace());
43+
for (size_t i = 0; i < lod0.size() - 1; ++i) {
44+
for (size_t idx = lod0[i]; idx < lod0[i + 1]; ++idx) {
45+
for (int word_idx = 0; word_idx < win_size; ++word_idx) {
46+
size_t word_pos = idx + word_idx;
47+
out_data[win_size * idx + word_idx] =
48+
word_pos < lod0[i + 1] ? in_data[word_pos] : pad_value;
49+
}
50+
}
51+
}
52+
}
53+
};
54+
55+
} // namespace operators
56+
} // namespace paddle

python/paddle/fluid/layers/nn.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@
111111
'stack',
112112
'pad2d',
113113
'unstack',
114+
'sequence_enumerate',
114115
]
115116

116117

@@ -5823,6 +5824,51 @@ def flatten(x, axis=1, name=None):
58235824
return out
58245825

58255826

5827+
def sequence_enumerate(input, win_size, pad_value=0, name=None):
5828+
"""
5829+
Generate a new sequence for the input index sequence, which enumerates all the
5830+
sub-sequences with length `win_size` of the input.
5831+
The enumerated sequence has the same 1st dimension with variable `input`, and
5832+
the 2nd dimension is `win_size`, padded by `pad_value` if necessary in generation.
5833+
5834+
Examples:
5835+
Case 1:
5836+
Input:
5837+
X.lod = [[0, 3, 5]]
5838+
X.data = [[1], [2], [3], [4], [5]]
5839+
X.dims = [5, 1]
5840+
Attrs:
5841+
win_size = 2
5842+
pad_value = 0
5843+
Output:
5844+
Out.lod = [[0, 3, 5]]
5845+
Out.data = [[1, 2], [2, 3], [3, 0], [4, 5], [5, 0]]
5846+
Out.dims = [5, 2]
5847+
5848+
Args:
5849+
input (Variable): The input variable which is a index sequence.
5850+
win_size (int): The window size for enumerating all sub-sequences.
5851+
pad_value (int): The padding value, default 0.
5852+
5853+
Returns:
5854+
Variable: The enumerate sequence variable which is a LoDTensor.
5855+
5856+
Examples:
5857+
.. code-block:: python
5858+
5859+
x = fluid.layers.data(shape[30, 1], dtype='int32', lod_level=1)
5860+
out = fluid.layers.sequence_enumerate(input=x, win_size=3, pad_value=0)
5861+
"""
5862+
helper = LayerHelper('sequence_enumerate', **locals())
5863+
out = helper.create_tmp_variable(helper.input_dtype(), stop_gradient=True)
5864+
helper.append_op(
5865+
type='sequence_enumerate',
5866+
inputs={'X': input},
5867+
outputs={'Out': out},
5868+
attrs={'win_size': win_size,
5869+
'pad_value': pad_value})
5870+
5871+
58265872
def sequence_mask(x, maxlen=None, dtype='int64', name=None):
58275873
"""
58285874
**SequenceMask Layer**
@@ -5902,6 +5948,7 @@ def stack(x, axis=0):
59025948
helper.append_op(
59035949
type='stack', inputs={'X': x}, outputs={'Y': out},
59045950
attrs={'axis': axis})
5951+
59055952
return out
59065953

59075954

python/paddle/fluid/tests/unittests/test_layers.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,13 @@ def test_prelu(self):
549549
self.assertIsNotNone(out)
550550
print(str(program))
551551

552+
def test_sequence_enumerate(self):
553+
program = Program()
554+
with program_guard(program):
555+
x = layers.data(name="input", shape=[1], dtype='int32', lod_level=1)
556+
out = layers.sequence_enumerate(input=x, win_size=2, pad_value=0)
557+
print(str(program))
558+
552559

553560
if __name__ == '__main__':
554561
unittest.main()

0 commit comments

Comments
 (0)