Skip to content

Commit 8809d43

Browse files
author
Yibing Liu
committed
Remove unnecessary dtype conversion & register int64 kernels
1 parent 7a2aa48 commit 8809d43

File tree

3 files changed

+26
-35
lines changed

3 files changed

+26
-35
lines changed

paddle/operators/sequence_erase_op.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,4 +86,5 @@ REGISTER_OP_WITHOUT_GRADIENT(sequence_erase, ops::SequenceEraseOp,
8686
ops::SequenceEraseOpMaker);
8787
REGISTER_OP_CPU_KERNEL(
8888
sequence_erase,
89-
ops::SequenceEraseKernel<paddle::platform::CPUDeviceContext, int32_t>);
89+
ops::SequenceEraseKernel<paddle::platform::CPUDeviceContext, int32_t>,
90+
ops::SequenceEraseKernel<paddle::platform::CPUDeviceContext, int64_t>);

paddle/operators/sequence_erase_op.cu

Lines changed: 8 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,12 @@ __global__ void LabelErasedIdx(const T* in_dat, const int64_t in_len,
2828
size_t* num_erased) {
2929
int index = blockIdx.x * blockDim.x + threadIdx.x;
3030
if (index < in_len) {
31-
int erased = 0;
3231
for (size_t i = 0; i < tokens_len; ++i) {
3332
if (in_dat[index] == tokens[i]) {
34-
erased = 1;
33+
num_erased[index + 1] = 1;
34+
break;
3535
}
3636
}
37-
num_erased[index + 1] = erased;
38-
if (index == 0) {
39-
num_erased[0] = 0;
40-
}
4137
}
4238
}
4339

@@ -60,26 +56,6 @@ __global__ void SetOutput(const T* in_dat, const int64_t in_len,
6056
}
6157
}
6258

63-
template <typename T, typename Vector>
64-
thrust::device_vector<T> set_device_vector(Vector& vector) {
65-
thrust::host_vector<T> host_vec(vector.size());
66-
for (size_t i = 0; i < vector.size(); ++i) {
67-
host_vec[i] = vector[i];
68-
}
69-
thrust::device_vector<T> dev_vec = host_vec;
70-
return dev_vec;
71-
}
72-
73-
template <typename T>
74-
std::vector<T> get_std_vector(thrust::device_vector<T>& dev_vec) {
75-
thrust::host_vector<T> host_vec = dev_vec;
76-
std::vector<T> std_vec(host_vec.size(), 0);
77-
for (size_t i = 0; i < host_vec.size(); ++i) {
78-
std_vec[i] = host_vec[i];
79-
}
80-
return std_vec;
81-
}
82-
8359
template <typename T>
8460
class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
8561
public:
@@ -95,12 +71,11 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
9571
auto in_len = in->numel();
9672
auto in_dat = in->data<T>();
9773
// Copy tokens to GPU
98-
thrust::device_vector<int> dev_tokens =
99-
set_device_vector<int, std::vector<int>>(tokens);
74+
thrust::device_vector<int> dev_tokens(tokens.begin(), tokens.end());
10075
int* dev_tokens_ptr = thrust::raw_pointer_cast(dev_tokens.data());
10176

10277
// Count number of elements to be erased
103-
thrust::device_vector<size_t> num_erased(in_len + 1);
78+
thrust::device_vector<size_t> num_erased(in_len + 1, 0);
10479
size_t* num_erased_ptr = thrust::raw_pointer_cast(num_erased.data());
10580
auto stream = ctx.cuda_device_context().stream();
10681
LabelErasedIdx<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
@@ -112,8 +87,7 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
11287
// Copy LoD to GPU
11388
auto lod0 = lod[0];
11489
auto lod_len = lod0.size();
115-
thrust::device_vector<size_t> dev_in_lod =
116-
set_device_vector<size_t, paddle::framework::Vector<size_t>>(lod0);
90+
thrust::device_vector<size_t> dev_in_lod = lod0;
11791
size_t* dev_in_lod_ptr = thrust::raw_pointer_cast(dev_in_lod.data());
11892

11993
// Calc output LoD
@@ -124,7 +98,7 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
12498
num_erased_ptr, dev_in_lod_ptr, lod_len, dev_out_lod_ptr);
12599

126100
// Set LoD for output
127-
std::vector<size_t> out_lod0 = get_std_vector<size_t>(dev_out_lod);
101+
thrust::host_vector<size_t> out_lod0 = dev_out_lod;
128102
framework::LoD out_lod;
129103
out_lod.push_back(out_lod0);
130104
out->set_lod(out_lod);
@@ -142,4 +116,5 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
142116
} // namespace paddle
143117

144118
REGISTER_OP_CUDA_KERNEL(sequence_erase,
145-
paddle::operators::SequenceEraseOpCUDAKernel<int32_t>);
119+
paddle::operators::SequenceEraseOpCUDAKernel<int32_t>,
120+
paddle::operators::SequenceEraseOpCUDAKernel<int64_t>);

python/paddle/v2/fluid/tests/test_sequence_erase_op.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def sequence_erase(in_seq, lod0, tokens):
2929
return np.array(out_seq).astype("int32"), new_lod0
3030

3131

32-
class TestSequenceEraseOp(OpTest):
32+
class TestSequenceEraseOpInt32(OpTest):
3333
def setUp(self):
3434
self.op_type = "sequence_erase"
3535
in_seq = np.random.randint(0, 10, (30, 1)).astype("int32")
@@ -44,6 +44,21 @@ def test_check_output(self):
4444
self.check_output()
4545

4646

47+
class TestSequenceEraseOpInt64(OpTest):
48+
def setUp(self):
49+
self.op_type = "sequence_erase"
50+
in_seq = np.random.randint(0, 10, (30, 1)).astype("int64")
51+
lod = [[0, 9, 13, 24, 30]]
52+
tokens = [2, 3, 5]
53+
out_seq, new_lod0 = sequence_erase(in_seq, lod[0], tokens)
54+
self.attrs = {'tokens': tokens}
55+
self.inputs = {'X': (in_seq, lod)}
56+
self.outputs = {'Out': (out_seq, [new_lod0])}
57+
58+
def test_check_output(self):
59+
self.check_output()
60+
61+
4762
class TestSequenceEraseOpEmpty(OpTest):
4863
def setUp(self):
4964
self.op_type = "sequence_erase"

0 commit comments

Comments
 (0)