Skip to content

Commit 59bf85d

Browse files
Merge pull request #7325 from kuke/sequence_erase_op
Add sequence erase op
2 parents da3087a + 1077946 commit 59bf85d

File tree

4 files changed

+327
-0
lines changed

4 files changed

+327
-0
lines changed

paddle/operators/sequence_erase_op.cc

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/operators/sequence_erase_op.h"
16+
17+
namespace paddle {
18+
namespace operators {
19+
20+
class SequenceEraseOp : public framework::OperatorWithKernel {
21+
public:
22+
using framework::OperatorWithKernel::OperatorWithKernel;
23+
24+
void InferShape(framework::InferShapeContext* ctx) const override {
25+
PADDLE_ENFORCE(ctx->HasInput("X"),
26+
"Input(X) of SequenceEraseOp should not be null.");
27+
PADDLE_ENFORCE(ctx->HasOutput("Out"),
28+
"Output(Out) of SequenceEraseOp should not be null.");
29+
auto x_dims = ctx->GetInputDim("X");
30+
PADDLE_ENFORCE(x_dims.size() == 2 && x_dims[1] == 1,
31+
"Input(X) of SequenceEraseOp should be a 2-D LoDTensor "
32+
"with the 2nd dimension equal to 1.");
33+
ctx->SetOutputDim("Out", x_dims);
34+
}
35+
};
36+
37+
class SequenceEraseOpMaker : public framework::OpProtoAndCheckerMaker {
38+
public:
39+
SequenceEraseOpMaker(OpProto* proto, OpAttrChecker* op_checker)
40+
: OpProtoAndCheckerMaker(proto, op_checker) {
41+
AddInput("X",
42+
"(2-D LoDTensor with the 2nd dim. equal to 1) "
43+
"Input LoDTensor of SequenceEraseOp.");
44+
AddOutput("Out",
45+
"(2-D LoDTensor with the 2nd dim. equal to 1) "
46+
"Output LoDTensor of SequenceEraseOp.");
47+
AddAttr<std::vector<int>>("tokens",
48+
"(vector<int>) Tokens need to be erased from "
49+
"input sequences.");
50+
AddComment(R"DOC(
51+
Sequence Erase Operator.
52+
53+
Sequence erase operator erases tokens specified by Attr(tokens) from the input
54+
sequences Input(X), and outputs the remaining data and modifies the LoD
55+
information at the same time. For example, given a 2-D LoDTensor
56+
57+
X = [[2, 2, 6, 1, 3, 9, 6, 1, 0, 1]]^T
58+
59+
with lod = [[0, 3, 6, 10]], there are three sequences in the input:
60+
61+
X1 = [[2, 2, 6]]^T, X2 = [[1, 3, 9]]^T and X3 = [[6, 1, 0, 1]]^T.
62+
63+
If the tokens to be erased are Attr(tokens) = [2, 3, 5], after the erasing
64+
operation, the three sequences become
65+
66+
X1' = [[6]]^T, X2' = [[1, 9]]^T and X3' = [[6, 1, 0, 1]]^T.
67+
68+
Hence the LoDTensor Output(Out) should be
69+
70+
Out = [[6, 1, 9, 6, 1, 0, 1]]^T,
71+
72+
with lod = [[0, 1, 3, 7]].
73+
74+
An example usage for this operator is to remove the special tokens when
75+
computing the edit distance between two strings, such as blank, start token,
76+
and end token.
77+
)DOC");
78+
}
79+
};
80+
81+
} // namespace operators
82+
} // namespace paddle
83+
84+
namespace ops = paddle::operators;
85+
REGISTER_OP_WITHOUT_GRADIENT(sequence_erase, ops::SequenceEraseOp,
86+
ops::SequenceEraseOpMaker);
87+
REGISTER_OP_CPU_KERNEL(
88+
sequence_erase,
89+
ops::SequenceEraseKernel<paddle::platform::CPUDeviceContext, int32_t>);

paddle/operators/sequence_erase_op.cu

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include <thrust/device_vector.h>
16+
#include <thrust/host_vector.h>
17+
#include "paddle/operators/sequence_erase_op.h"
18+
#include "paddle/platform/cuda_helper.h"
19+
20+
namespace paddle {
21+
namespace operators {
22+
using platform::PADDLE_CUDA_NUM_THREADS;
23+
using LoDTensor = framework::LoDTensor;
24+
25+
template <typename T>
26+
__global__ void LabelErasedIdx(const T* in_dat, const int in_len,
27+
const T* tokens, const int tokens_len,
28+
int* num_erased) {
29+
int index = blockIdx.x * blockDim.x + threadIdx.x;
30+
if (index < in_len) {
31+
int erased = 0;
32+
for (int i = 0; i < tokens_len; ++i) {
33+
if (in_dat[index] == tokens[i]) {
34+
erased = 1;
35+
}
36+
}
37+
num_erased[index + 1] = erased;
38+
if (index == 0) {
39+
num_erased[0] = 0;
40+
}
41+
}
42+
}
43+
44+
template <typename T>
45+
__global__ void GetOutLod(const T* num_erased, const int* in_lod,
46+
const int lod_len, int* out_lod0) {
47+
int index = blockIdx.x * blockDim.x + threadIdx.x;
48+
if (index < lod_len) {
49+
out_lod0[index] = in_lod[index] - num_erased[in_lod[index]];
50+
}
51+
}
52+
53+
template <typename T>
54+
__global__ void SetOutput(const T* in_dat, const int in_len,
55+
const int* num_erased, T* out_dat) {
56+
int index = blockIdx.x * blockDim.x + threadIdx.x;
57+
if (index < in_len) {
58+
if (in_dat[index] != in_dat[index + 1]) {
59+
out_dat[index - num_erased[index]] = in_dat[index];
60+
}
61+
}
62+
}
63+
64+
template <typename T>
65+
class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
66+
public:
67+
void Compute(const framework::ExecutionContext& ctx) const override {
68+
auto* in = ctx.Input<LoDTensor>("X");
69+
auto* out = ctx.Output<LoDTensor>("Out");
70+
71+
auto lod = in->lod();
72+
PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
73+
PADDLE_ENFORCE_EQ(lod[0].back(), (size_t)in->numel(),
74+
"The actual size mismatches with the LoD information.");
75+
auto tokens = ctx.Attr<std::vector<T>>("tokens");
76+
auto tokens_len = tokens.size();
77+
auto in_len = in->numel();
78+
auto in_dat = in->data<T>();
79+
auto lod0 = lod[0];
80+
81+
thrust::host_vector<T> host_tokens(tokens_len);
82+
for (size_t i = 0; i < tokens.size(); ++i) {
83+
host_tokens[i] = tokens[i];
84+
}
85+
thrust::device_vector<T> dev_tokens = host_tokens;
86+
thrust::device_vector<int> num_erased(in_len + 1);
87+
88+
T* dev_tokens_ptr = thrust::raw_pointer_cast(dev_tokens.data());
89+
int* num_erased_ptr = thrust::raw_pointer_cast(num_erased.data());
90+
91+
auto stream = ctx.cuda_device_context().stream();
92+
LabelErasedIdx<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
93+
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
94+
in_dat, in_len, dev_tokens_ptr, tokens_len, num_erased_ptr);
95+
thrust::inclusive_scan(num_erased.begin() + 1, num_erased.end(),
96+
num_erased.begin() + 1);
97+
98+
// Calc LoD
99+
auto lod_len = lod0.size();
100+
thrust::host_vector<int> host_lod(lod_len);
101+
for (size_t i = 0; i < lod_len; ++i) {
102+
host_lod[i] = lod0[i];
103+
}
104+
thrust::device_vector<int> dev_in_lod = host_lod;
105+
thrust::device_vector<int> dev_out_lod(lod_len);
106+
int* dev_in_lod_ptr = thrust::raw_pointer_cast(dev_in_lod.data());
107+
int* dev_out_lod_ptr = thrust::raw_pointer_cast(dev_out_lod.data());
108+
GetOutLod<<<(lod_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
109+
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
110+
num_erased_ptr, dev_in_lod_ptr, lod_len, dev_out_lod_ptr);
111+
thrust::host_vector<int> host_out_lod = dev_out_lod;
112+
std::vector<int> out_lod0(lod_len, 0);
113+
for (size_t i = 0; i < lod_len; i++) {
114+
out_lod0[i] = host_out_lod[i];
115+
}
116+
framework::LoD out_lod;
117+
out_lod.push_back(out_lod0);
118+
out->set_lod(out_lod);
119+
120+
// Set output
121+
out->Resize({out_lod0.back(), 1});
122+
auto out_dat = out->mutable_data<T>(ctx.GetPlace());
123+
SetOutput<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
124+
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_dat, in_len,
125+
num_erased_ptr, out_dat);
126+
}
127+
};
128+
129+
} // namespace operators
130+
} // namespace paddle
131+
132+
REGISTER_OP_CUDA_KERNEL(sequence_erase,
133+
paddle::operators::SequenceEraseOpCUDAKernel<int32_t>);

paddle/operators/sequence_erase_op.h

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include "paddle/framework/op_registry.h"
18+
19+
namespace paddle {
20+
namespace operators {
21+
22+
template <typename DeviceContext, typename T>
23+
class SequenceEraseKernel : public framework::OpKernel<T> {
24+
public:
25+
void Compute(const framework::ExecutionContext& ctx) const override {
26+
auto* in = ctx.Input<framework::LoDTensor>("X");
27+
auto* out = ctx.Output<framework::LoDTensor>("Out");
28+
29+
auto lod = in->lod();
30+
PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
31+
PADDLE_ENFORCE_EQ(lod[0].back(), (size_t)in->numel(),
32+
"The actual size mismatches with the LoD information.");
33+
auto tokens = ctx.Attr<std::vector<int>>("tokens");
34+
auto in_len = in->numel();
35+
auto in_dat = in->data<T>();
36+
auto lod0 = lod[0];
37+
38+
std::vector<size_t> num_erased(in_len + 1, 0);
39+
std::vector<size_t> out_lod0(1, 0);
40+
for (size_t i = 0; i < lod0.size() - 1; ++i) {
41+
size_t num_out = 0;
42+
for (auto j = lod0[i] + 1; j <= lod0[i + 1]; ++j) {
43+
num_erased[j] = num_erased[j - 1];
44+
if (std::find(tokens.begin(), tokens.end(), in_dat[j - 1]) !=
45+
tokens.end()) {
46+
num_erased[j] += 1;
47+
} else {
48+
num_out += 1;
49+
}
50+
}
51+
out_lod0.push_back(out_lod0.back() + num_out);
52+
}
53+
54+
auto out_len = in_len - num_erased[in_len];
55+
out->Resize({static_cast<int64_t>(out_len), 1});
56+
auto out_dat = out->mutable_data<T>(ctx.GetPlace());
57+
58+
for (int64_t i = 0; i < in_len; ++i) {
59+
if (num_erased[i] == num_erased[i + 1]) {
60+
out_dat[i - num_erased[i]] = in_dat[i];
61+
}
62+
}
63+
framework::LoD out_lod;
64+
out_lod.push_back(out_lod0);
65+
out->set_lod(out_lod);
66+
}
67+
};
68+
69+
} // namespace operators
70+
} // namespace paddle
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import unittest
2+
import numpy as np
3+
from op_test import OpTest
4+
5+
6+
def sequence_erase(in_seq, lod0, tokens):
7+
new_lod0 = [0]
8+
out_seq = []
9+
for i in range(0, len(lod0) - 1):
10+
num_out = 0
11+
for dat in in_seq[lod0[i]:lod0[i + 1]]:
12+
if dat not in tokens:
13+
out_seq.append(dat)
14+
num_out += 1
15+
new_lod0.append(new_lod0[-1] + num_out)
16+
return np.array(out_seq).astype("int32"), new_lod0
17+
18+
19+
class TestSequenceEraseOp(OpTest):
20+
def setUp(self):
21+
self.op_type = "sequence_erase"
22+
in_seq = np.random.randint(0, 10, (30, 1)).astype("int32")
23+
lod = [[0, 9, 13, 24, 30]]
24+
tokens = [2, 3, 5]
25+
out_seq, new_lod0 = sequence_erase(in_seq, lod[0], tokens)
26+
self.attrs = {'tokens': tokens}
27+
self.inputs = {'X': (in_seq, lod)}
28+
self.outputs = {'Out': (out_seq, [new_lod0])}
29+
30+
def test_check_output(self):
31+
self.check_output()
32+
33+
34+
if __name__ == '__main__':
35+
unittest.main()

0 commit comments

Comments
 (0)