@@ -23,39 +23,34 @@ using platform::PADDLE_CUDA_NUM_THREADS;
 using LoDTensor = framework::LoDTensor;
 
 template <typename T>
-__global__ void LabelErasedIdx(const T* in_dat, const int in_len,
-                               const T* tokens, const int tokens_len,
-                               int* num_erased) {
+__global__ void LabelErasedIdx(const T* in_dat, const int64_t in_len,
+                               const int* tokens, const size_t tokens_len,
+                               size_t* num_erased) {
   int index = blockIdx.x * blockDim.x + threadIdx.x;
   if (index < in_len) {
-    int erased = 0;
-    for (int i = 0; i < tokens_len; ++i) {
+    for (size_t i = 0; i < tokens_len; ++i) {
       if (in_dat[index] == tokens[i]) {
-        erased = 1;
+        num_erased[index + 1] = 1;
+        break;
       }
     }
-    num_erased[index + 1] = erased;
-    if (index == 0) {
-      num_erased[0] = 0;
-    }
   }
 }
 
-template <typename T>
-__global__ void GetOutLod(const T* num_erased, const int* in_lod,
-                          const int lod_len, int* out_lod0) {
+__global__ void GetOutLod(const size_t* num_erased, const size_t* in_lod,
+                          const size_t lod_len, size_t* out_lod0) {
   int index = blockIdx.x * blockDim.x + threadIdx.x;
   if (index < lod_len) {
     out_lod0[index] = in_lod[index] - num_erased[in_lod[index]];
   }
 }
 
 template <typename T>
-__global__ void SetOutput(const T* in_dat, const int in_len,
-                          const int* num_erased, T* out_dat) {
+__global__ void SetOutput(const T* in_dat, const int64_t in_len,
+                          const size_t* num_erased, T* out_dat) {
   int index = blockIdx.x * blockDim.x + threadIdx.x;
   if (index < in_len) {
-    if (in_dat[index] != in_dat[index + 1]) {
+    if (num_erased[index] == num_erased[index + 1]) {
       out_dat[index - num_erased[index]] = in_dat[index];
     }
   }
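
Aside: the three kernels above implement GPU stream compaction via prefix sum — mark each erased element, inclusive-scan the marks into running erase counts, then scatter each surviving element left by its shift. A minimal self-contained sketch of the same pattern (toy code with made-up names, not part of this commit):

#include <cstdio>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/scan.h>

__global__ void Mark(const int* in, int n, const int* tokens, int n_tok,
                     size_t* erased) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    for (int t = 0; t < n_tok; ++t) {
      if (in[i] == tokens[t]) {
        erased[i + 1] = 1;  // slot 0 stays 0 so the scan yields shifts
        break;
      }
    }
  }
}

__global__ void Scatter(const int* in, int n, const size_t* erased, int* out) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  // erased[i] == erased[i + 1] means element i survives; it moves left by
  // the number of erased elements that precede it.
  if (i < n && erased[i] == erased[i + 1]) out[i - erased[i]] = in[i];
}

int main() {
  const int h_in[] = {1, 2, 3, 4, 2};
  const int h_tok[] = {2};
  thrust::device_vector<int> in(h_in, h_in + 5);
  thrust::device_vector<int> tokens(h_tok, h_tok + 1);
  thrust::device_vector<size_t> erased(6, 0);  // n + 1 slots, zero-initialized
  Mark<<<1, 32>>>(thrust::raw_pointer_cast(in.data()), 5,
                  thrust::raw_pointer_cast(tokens.data()), 1,
                  thrust::raw_pointer_cast(erased.data()));
  // Turn the 0/1 marks into "erased so far" counts: [0, 0, 1, 1, 1, 2]
  thrust::inclusive_scan(erased.begin() + 1, erased.end(), erased.begin() + 1);
  thrust::device_vector<int> out(5 - erased.back());
  Scatter<<<1, 32>>>(thrust::raw_pointer_cast(in.data()), 5,
                     thrust::raw_pointer_cast(erased.data()),
                     thrust::raw_pointer_cast(out.data()));
  thrust::host_vector<int> h_out = out;
  for (int v : h_out) printf("%d ", v);  // prints: 1 3 4
  printf("\n");
}

The extra leading slot in the marks array is what lets one inclusive scan double as an exclusive scan: erased[i] then reads as "erasures strictly before element i", which is exactly the scatter shift.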
@@ -72,53 +67,44 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
     PADDLE_ENFORCE_EQ(lod[0].back(), (size_t)in->numel(),
                       "The actual size mismatches with the LoD information.");
-    auto tokens = ctx.Attr<std::vector<T>>("tokens");
-    auto tokens_len = tokens.size();
+    auto tokens = ctx.Attr<std::vector<int>>("tokens");
     auto in_len = in->numel();
     auto in_dat = in->data<T>();
-    auto lod0 = lod[0];
-
-    thrust::host_vector<T> host_tokens(tokens_len);
-    for (size_t i = 0; i < tokens.size(); ++i) {
-      host_tokens[i] = tokens[i];
-    }
-    thrust::device_vector<T> dev_tokens = host_tokens;
-    thrust::device_vector<int> num_erased(in_len + 1);
-
-    T* dev_tokens_ptr = thrust::raw_pointer_cast(dev_tokens.data());
-    int* num_erased_ptr = thrust::raw_pointer_cast(num_erased.data());
+    // Copy tokens to GPU
+    thrust::device_vector<int> dev_tokens(tokens.begin(), tokens.end());
+    int* dev_tokens_ptr = thrust::raw_pointer_cast(dev_tokens.data());
 
+    // Count number of elements to be erased
+    thrust::device_vector<size_t> num_erased(in_len + 1, 0);
+    size_t* num_erased_ptr = thrust::raw_pointer_cast(num_erased.data());
     auto stream = ctx.cuda_device_context().stream();
     LabelErasedIdx<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
                      PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
-        in_dat, in_len, dev_tokens_ptr, tokens_len, num_erased_ptr);
+        in_dat, in_len, dev_tokens_ptr, tokens.size(), num_erased_ptr);
     thrust::inclusive_scan(num_erased.begin() + 1, num_erased.end(),
                            num_erased.begin() + 1);
 
-    // Calc LoD
+    // Copy LoD to GPU
+    auto lod0 = lod[0];
     auto lod_len = lod0.size();
-    thrust::host_vector<int> host_lod(lod_len);
-    for (size_t i = 0; i < lod_len; ++i) {
-      host_lod[i] = lod0[i];
-    }
-    thrust::device_vector<int> dev_in_lod = host_lod;
-    thrust::device_vector<int> dev_out_lod(lod_len);
-    int* dev_in_lod_ptr = thrust::raw_pointer_cast(dev_in_lod.data());
-    int* dev_out_lod_ptr = thrust::raw_pointer_cast(dev_out_lod.data());
+    thrust::device_vector<size_t> dev_in_lod = lod0;
+    size_t* dev_in_lod_ptr = thrust::raw_pointer_cast(dev_in_lod.data());
+
+    // Calc output LoD
+    thrust::device_vector<size_t> dev_out_lod(lod_len);
+    size_t* dev_out_lod_ptr = thrust::raw_pointer_cast(dev_out_lod.data());
     GetOutLod<<<(lod_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
                 PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
         num_erased_ptr, dev_in_lod_ptr, lod_len, dev_out_lod_ptr);
-    thrust::host_vector<int> host_out_lod = dev_out_lod;
-    std::vector<int> out_lod0(lod_len, 0);
-    for (size_t i = 0; i < lod_len; i++) {
-      out_lod0[i] = host_out_lod[i];
-    }
+
+    // Set LoD for output
+    thrust::host_vector<size_t> out_lod0 = dev_out_lod;
     framework::LoD out_lod;
     out_lod.push_back(out_lod0);
     out->set_lod(out_lod);
 
     // Set output
-    out->Resize({out_lod0.back(), 1});
+    out->Resize({static_cast<int64_t>(out_lod0.back()), 1});
     auto out_dat = out->mutable_data<T>(ctx.GetPlace());
     SetOutput<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
                 PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_dat, in_len,
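
Note on the launch configuration: `(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1` is ceiling division — just enough blocks to cover every element. With in_len = 1000 and 512 threads per block it yields 2 blocks (1024 threads), and the `if (index < in_len)` guard inside each kernel idles the 24 surplus threads. The equivalent, perhaps more familiar spelling (a hypothetical helper, not in this file):

// Ceiling division for grid sizing; equal to (n - 1) / block + 1 for n >= 1.
inline int64_t CeilDiv(int64_t n, int64_t block) {
  return (n + block - 1) / block;
}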
@@ -130,4 +116,5 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
 }  // namespace paddle
 
 REGISTER_OP_CUDA_KERNEL(sequence_erase,
-                        paddle::operators::SequenceEraseOpCUDAKernel<int32_t>);
+                        paddle::operators::SequenceEraseOpCUDAKernel<int32_t>,
+                        paddle::operators::SequenceEraseOpCUDAKernel<int64_t>);
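
For intuition, a hypothetical end-to-end trace (example data, not from the commit): tokens = {2}, input = [1, 2, 3, 4, 2], input LoD = [0, 3, 5].

- LabelErasedIdx marks token positions, offset by one slot: num_erased = [0, 0, 1, 0, 0, 1]
- inclusive_scan turns the marks into running erase counts: num_erased = [0, 0, 1, 1, 1, 2]
- GetOutLod shifts each LoD offset by the erasures before it: out_lod0 = [0 - 0, 3 - 1, 5 - 2] = [0, 2, 3]
- SetOutput keeps index i iff num_erased[i] == num_erased[i + 1] and writes it to slot i - num_erased[i]: output = [1, 3, 4]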