Skip to content

Commit a1c281f

Browse files
author
Yibing Liu
authored
Merge pull request #7603 from kuke/simplify_erase
Enhance GPU kernel of sequence erase op
2 parents 41b8388 + 8809d43 commit a1c281f

File tree

3 files changed

+66
-48
lines changed

3 files changed

+66
-48
lines changed

paddle/operators/sequence_erase_op.cc

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -86,4 +86,5 @@ REGISTER_OP_WITHOUT_GRADIENT(sequence_erase, ops::SequenceEraseOp,
8686
ops::SequenceEraseOpMaker);
8787
REGISTER_OP_CPU_KERNEL(
8888
sequence_erase,
89-
ops::SequenceEraseKernel<paddle::platform::CPUDeviceContext, int32_t>);
89+
ops::SequenceEraseKernel<paddle::platform::CPUDeviceContext, int32_t>,
90+
ops::SequenceEraseKernel<paddle::platform::CPUDeviceContext, int64_t>);

paddle/operators/sequence_erase_op.cu

Lines changed: 33 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -23,39 +23,34 @@ using platform::PADDLE_CUDA_NUM_THREADS;
2323
using LoDTensor = framework::LoDTensor;
2424

2525
template <typename T>
26-
__global__ void LabelErasedIdx(const T* in_dat, const int in_len,
27-
const T* tokens, const int tokens_len,
28-
int* num_erased) {
26+
__global__ void LabelErasedIdx(const T* in_dat, const int64_t in_len,
27+
const int* tokens, const size_t tokens_len,
28+
size_t* num_erased) {
2929
int index = blockIdx.x * blockDim.x + threadIdx.x;
3030
if (index < in_len) {
31-
int erased = 0;
32-
for (int i = 0; i < tokens_len; ++i) {
31+
for (size_t i = 0; i < tokens_len; ++i) {
3332
if (in_dat[index] == tokens[i]) {
34-
erased = 1;
33+
num_erased[index + 1] = 1;
34+
break;
3535
}
3636
}
37-
num_erased[index + 1] = erased;
38-
if (index == 0) {
39-
num_erased[0] = 0;
40-
}
4137
}
4238
}
4339

44-
template <typename T>
45-
__global__ void GetOutLod(const T* num_erased, const int* in_lod,
46-
const int lod_len, int* out_lod0) {
40+
__global__ void GetOutLod(const size_t* num_erased, const size_t* in_lod,
41+
const size_t lod_len, size_t* out_lod0) {
4742
int index = blockIdx.x * blockDim.x + threadIdx.x;
4843
if (index < lod_len) {
4944
out_lod0[index] = in_lod[index] - num_erased[in_lod[index]];
5045
}
5146
}
5247

5348
template <typename T>
54-
__global__ void SetOutput(const T* in_dat, const int in_len,
55-
const int* num_erased, T* out_dat) {
49+
__global__ void SetOutput(const T* in_dat, const int64_t in_len,
50+
const size_t* num_erased, T* out_dat) {
5651
int index = blockIdx.x * blockDim.x + threadIdx.x;
5752
if (index < in_len) {
58-
if (in_dat[index] != in_dat[index + 1]) {
53+
if (num_erased[index] == num_erased[index + 1]) {
5954
out_dat[index - num_erased[index]] = in_dat[index];
6055
}
6156
}
@@ -72,53 +67,44 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
7267
PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
7368
PADDLE_ENFORCE_EQ(lod[0].back(), (size_t)in->numel(),
7469
"The actual size mismatches with the LoD information.");
75-
auto tokens = ctx.Attr<std::vector<T>>("tokens");
76-
auto tokens_len = tokens.size();
70+
auto tokens = ctx.Attr<std::vector<int>>("tokens");
7771
auto in_len = in->numel();
7872
auto in_dat = in->data<T>();
79-
auto lod0 = lod[0];
80-
81-
thrust::host_vector<T> host_tokens(tokens_len);
82-
for (size_t i = 0; i < tokens.size(); ++i) {
83-
host_tokens[i] = tokens[i];
84-
}
85-
thrust::device_vector<T> dev_tokens = host_tokens;
86-
thrust::device_vector<int> num_erased(in_len + 1);
87-
88-
T* dev_tokens_ptr = thrust::raw_pointer_cast(dev_tokens.data());
89-
int* num_erased_ptr = thrust::raw_pointer_cast(num_erased.data());
73+
// Copy tokens to GPU
74+
thrust::device_vector<int> dev_tokens(tokens.begin(), tokens.end());
75+
int* dev_tokens_ptr = thrust::raw_pointer_cast(dev_tokens.data());
9076

77+
// Count number of elements to be erased
78+
thrust::device_vector<size_t> num_erased(in_len + 1, 0);
79+
size_t* num_erased_ptr = thrust::raw_pointer_cast(num_erased.data());
9180
auto stream = ctx.cuda_device_context().stream();
9281
LabelErasedIdx<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
9382
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
94-
in_dat, in_len, dev_tokens_ptr, tokens_len, num_erased_ptr);
83+
in_dat, in_len, dev_tokens_ptr, tokens.size(), num_erased_ptr);
9584
thrust::inclusive_scan(num_erased.begin() + 1, num_erased.end(),
9685
num_erased.begin() + 1);
9786

98-
// Calc LoD
87+
// Copy LoD to GPU
88+
auto lod0 = lod[0];
9989
auto lod_len = lod0.size();
100-
thrust::host_vector<int> host_lod(lod_len);
101-
for (size_t i = 0; i < lod_len; ++i) {
102-
host_lod[i] = lod0[i];
103-
}
104-
thrust::device_vector<int> dev_in_lod = host_lod;
105-
thrust::device_vector<int> dev_out_lod(lod_len);
106-
int* dev_in_lod_ptr = thrust::raw_pointer_cast(dev_in_lod.data());
107-
int* dev_out_lod_ptr = thrust::raw_pointer_cast(dev_out_lod.data());
90+
thrust::device_vector<size_t> dev_in_lod = lod0;
91+
size_t* dev_in_lod_ptr = thrust::raw_pointer_cast(dev_in_lod.data());
92+
93+
// Calc output LoD
94+
thrust::device_vector<size_t> dev_out_lod(lod_len);
95+
size_t* dev_out_lod_ptr = thrust::raw_pointer_cast(dev_out_lod.data());
10896
GetOutLod<<<(lod_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
10997
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
11098
num_erased_ptr, dev_in_lod_ptr, lod_len, dev_out_lod_ptr);
111-
thrust::host_vector<int> host_out_lod = dev_out_lod;
112-
std::vector<int> out_lod0(lod_len, 0);
113-
for (size_t i = 0; i < lod_len; i++) {
114-
out_lod0[i] = host_out_lod[i];
115-
}
99+
100+
// Set LoD for output
101+
thrust::host_vector<size_t> out_lod0 = dev_out_lod;
116102
framework::LoD out_lod;
117103
out_lod.push_back(out_lod0);
118104
out->set_lod(out_lod);
119105

120106
// Set output
121-
out->Resize({out_lod0.back(), 1});
107+
out->Resize({static_cast<int64_t>(out_lod0.back()), 1});
122108
auto out_dat = out->mutable_data<T>(ctx.GetPlace());
123109
SetOutput<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
124110
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_dat, in_len,
@@ -130,4 +116,5 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
130116
} // namespace paddle
131117

132118
REGISTER_OP_CUDA_KERNEL(sequence_erase,
133-
paddle::operators::SequenceEraseOpCUDAKernel<int32_t>);
119+
paddle::operators::SequenceEraseOpCUDAKernel<int32_t>,
120+
paddle::operators::SequenceEraseOpCUDAKernel<int64_t>);

python/paddle/v2/fluid/tests/test_sequence_erase_op.py

Lines changed: 31 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,7 @@ def sequence_erase(in_seq, lod0, tokens):
2929
return np.array(out_seq).astype("int32"), new_lod0
3030

3131

32-
class TestSequenceEraseOp(OpTest):
32+
class TestSequenceEraseOpInt32(OpTest):
3333
def setUp(self):
3434
self.op_type = "sequence_erase"
3535
in_seq = np.random.randint(0, 10, (30, 1)).astype("int32")
@@ -44,5 +44,35 @@ def test_check_output(self):
4444
self.check_output()
4545

4646

47+
class TestSequenceEraseOpInt64(OpTest):
48+
def setUp(self):
49+
self.op_type = "sequence_erase"
50+
in_seq = np.random.randint(0, 10, (30, 1)).astype("int64")
51+
lod = [[0, 9, 13, 24, 30]]
52+
tokens = [2, 3, 5]
53+
out_seq, new_lod0 = sequence_erase(in_seq, lod[0], tokens)
54+
self.attrs = {'tokens': tokens}
55+
self.inputs = {'X': (in_seq, lod)}
56+
self.outputs = {'Out': (out_seq, [new_lod0])}
57+
58+
def test_check_output(self):
59+
self.check_output()
60+
61+
62+
class TestSequenceEraseOpEmpty(OpTest):
63+
def setUp(self):
64+
self.op_type = "sequence_erase"
65+
in_seq = np.random.randint(0, 10, (30, 1)).astype("int32")
66+
lod = [[0, 9, 13, 24, 30]]
67+
tokens = []
68+
out_seq, new_lod0 = sequence_erase(in_seq, lod[0], tokens)
69+
self.attrs = {'tokens': tokens}
70+
self.inputs = {'X': (in_seq, lod)}
71+
self.outputs = {'Out': (out_seq, [new_lod0])}
72+
73+
def test_check_output(self):
74+
self.check_output()
75+
76+
4777
if __name__ == '__main__':
4878
unittest.main()

0 commit comments

Comments
 (0)