@@ -28,16 +28,12 @@ __global__ void LabelErasedIdx(const T* in_dat, const int64_t in_len,
                                size_t* num_erased) {
   int index = blockIdx.x * blockDim.x + threadIdx.x;
   if (index < in_len) {
-    int erased = 0;
     for (size_t i = 0; i < tokens_len; ++i) {
       if (in_dat[index] == tokens[i]) {
-        erased = 1;
+        num_erased[index + 1] = 1;
+        break;
       }
     }
-    num_erased[index + 1] = erased;
-    if (index == 0) {
-      num_erased[0] = 0;
-    }
   }
 }
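The rewritten kernel leans on the host now zero-filling `num_erased` (see the `num_erased(in_len + 1, 0)` change further down): each thread only ever writes a `1` into slot `index + 1`, so the per-thread `erased` flag and the `index == 0` special case become unnecessary, and `break` stops scanning once a token matches. The off-by-one layout keeps `num_erased[0] == 0` as the identity for the prefix sum that presumably follows (the scan itself is outside the shown hunks). A minimal standalone sketch of that mark-then-scan pattern; the names and the `inclusive_scan` call are illustrative, not taken from this diff:

```cuda
#include <cstdio>
#include <vector>
#include <thrust/device_vector.h>
#include <thrust/scan.h>

// Hypothetical standalone version of the mark step: each thread sets
// num_erased[i + 1] = 1 if element i matches any token; slot 0 stays 0.
__global__ void MarkErased(const int* in, int in_len, const int* tokens,
                           int tokens_len, size_t* num_erased) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index < in_len) {
    for (int i = 0; i < tokens_len; ++i) {
      if (in[index] == tokens[i]) {
        num_erased[index + 1] = 1;
        break;
      }
    }
  }
}

int main() {
  std::vector<int> h_in = {2, 3, 2, 5};
  std::vector<int> h_tok = {2};
  thrust::device_vector<int> in(h_in.begin(), h_in.end());
  thrust::device_vector<int> tokens(h_tok.begin(), h_tok.end());
  thrust::device_vector<size_t> num_erased(in.size() + 1, 0);  // pre-zeroed
  MarkErased<<<1, 256>>>(thrust::raw_pointer_cast(in.data()), in.size(),
                         thrust::raw_pointer_cast(tokens.data()), tokens.size(),
                         thrust::raw_pointer_cast(num_erased.data()));
  // After the prefix sum, num_erased[i] counts erased elements before index i.
  thrust::inclusive_scan(num_erased.begin(), num_erased.end(),
                         num_erased.begin());
  printf("total erased: %zu\n", static_cast<size_t>(num_erased.back()));  // 2
  return 0;
}
```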
@@ -60,26 +56,6 @@ __global__ void SetOutput(const T* in_dat, const int64_t in_len,
   }
 }
 
-template <typename T, typename Vector>
-thrust::device_vector<T> set_device_vector(Vector& vector) {
-  thrust::host_vector<T> host_vec(vector.size());
-  for (size_t i = 0; i < vector.size(); ++i) {
-    host_vec[i] = vector[i];
-  }
-  thrust::device_vector<T> dev_vec = host_vec;
-  return dev_vec;
-}
-
-template <typename T>
-std::vector<T> get_std_vector(thrust::device_vector<T>& dev_vec) {
-  thrust::host_vector<T> host_vec = dev_vec;
-  std::vector<T> std_vec(host_vec.size(), 0);
-  for (size_t i = 0; i < host_vec.size(); ++i) {
-    std_vec[i] = host_vec[i];
-  }
-  return std_vec;
-}
-
 template <typename T>
 class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
  public:
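The deleted `set_device_vector` copied its input element by element into a `host_vector` and only then assigned to a `device_vector`; Thrust's iterator-range constructor expresses the same host-to-device copy in one line, which is exactly what the replacement in the next hunk uses. A minimal sketch with illustrative names:

```cuda
#include <vector>
#include <thrust/device_vector.h>

int main() {
  std::vector<int> tokens = {2, 3, 5};
  // Range constructor: one host-to-device copy, no hand-written loop.
  thrust::device_vector<int> dev_tokens(tokens.begin(), tokens.end());
  return dev_tokens.size() == tokens.size() ? 0 : 1;
}
```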
@@ -95,12 +71,11 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
     auto in_len = in->numel();
     auto in_dat = in->data<T>();
     // Copy tokens to GPU
-    thrust::device_vector<int> dev_tokens =
-        set_device_vector<int, std::vector<int>>(tokens);
+    thrust::device_vector<int> dev_tokens(tokens.begin(), tokens.end());
     int* dev_tokens_ptr = thrust::raw_pointer_cast(dev_tokens.data());
 
     // Count number of elements to be erased
-    thrust::device_vector<size_t> num_erased(in_len + 1);
+    thrust::device_vector<size_t> num_erased(in_len + 1, 0);
     size_t* num_erased_ptr = thrust::raw_pointer_cast(num_erased.data());
     auto stream = ctx.cuda_device_context().stream();
     LabelErasedIdx<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
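Note the zero-initializing constructor `num_erased(in_len + 1, 0)` is what lets the kernel above drop its special cases. The grid size uses the usual ceiling-division idiom: `(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1` equals `ceil(in_len / threads)` for `in_len >= 1`, so every element gets a thread. A quick check, with the threads-per-block value assumed for illustration rather than read from the diff:

```cuda
#include <cassert>

int main() {
  const int threads = 512;  // stand-in for PADDLE_CUDA_NUM_THREADS (assumed)
  assert((1 - 1) / threads + 1 == 1);    // 1 element    -> 1 block
  assert((512 - 1) / threads + 1 == 1);  // exactly full -> 1 block
  assert((513 - 1) / threads + 1 == 2);  // one over     -> 2 blocks
  return 0;
}
```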
@@ -112,8 +87,7 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
     // Copy LoD to GPU
     auto lod0 = lod[0];
     auto lod_len = lod0.size();
-    thrust::device_vector<size_t> dev_in_lod =
-        set_device_vector<size_t, paddle::framework::Vector<size_t>>(lod0);
+    thrust::device_vector<size_t> dev_in_lod = lod0;
     size_t* dev_in_lod_ptr = thrust::raw_pointer_cast(dev_in_lod.data());
 
     // Calc output LoD
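The output-LoD kernel invoked in the next hunk (only its trailing argument list is visible there) presumably shrinks each input offset by the number of elements erased before it, using the prefix-summed `num_erased`. A hedged sketch of that computation; the kernel name and exact signature are assumptions, not from the diff:

```cuda
#include <cstddef>

// Each output offset is the input offset minus the count of erased elements
// in the first in_lod[index] positions (num_erased is already scanned).
__global__ void GetOutLodSketch(const size_t* num_erased, const size_t* in_lod,
                                int lod_len, size_t* out_lod) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index < lod_len) {
    out_lod[index] = in_lod[index] - num_erased[in_lod[index]];
  }
}
```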
@@ -124,7 +98,7 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
                  num_erased_ptr, dev_in_lod_ptr, lod_len, dev_out_lod_ptr);
 
     // Set LoD for output
-    std::vector<size_t> out_lod0 = get_std_vector<size_t>(dev_out_lod);
+    thrust::host_vector<size_t> out_lod0 = dev_out_lod;
     framework::LoD out_lod;
     out_lod.push_back(out_lod0);
     out->set_lod(out_lod);
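Symmetrically to the token copy, the deleted `get_std_vector` staged a device-to-host transfer through a `host_vector` and then copied element by element into a `std::vector`; cross-space assignment from a `device_vector` to a `thrust::host_vector` already performs that transfer, which is all the replacement line does. A minimal sketch with stand-in data:

```cuda
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>

int main() {
  thrust::device_vector<size_t> dev_out_lod(4, 7);  // stand-in device data
  // Cross-space assignment: one device-to-host copy replaces get_std_vector.
  thrust::host_vector<size_t> out_lod0 = dev_out_lod;
  return out_lod0.size() == 4 ? 0 : 1;
}
```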
@@ -142,4 +116,5 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
 }  // namespace paddle
 
 REGISTER_OP_CUDA_KERNEL(sequence_erase,
-                        paddle::operators::SequenceEraseOpCUDAKernel<int32_t>);
+                        paddle::operators::SequenceEraseOpCUDAKernel<int32_t>,
+                        paddle::operators::SequenceEraseOpCUDAKernel<int64_t>);
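On the last hunk: `REGISTER_OP_CUDA_KERNEL` accepts a list of kernel classes after the op name, so listing the `int64_t` instantiation alongside `int32_t` is all that is needed for the templated kernel to serve 64-bit integer inputs as well.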