@@ -19,46 +19,32 @@ namespace paddle {
 namespace operators {
 namespace math {
 
-template <typename T, bool Padding>
+template <typename T, CopyType Type>
 __global__ void SequencePaddingKernel(
-    T* pad_data, T* seq_data, const size_t* seq_offset, const size_t& seq_num,
-    const size_t& max_seq_len, const size_t& seq_width, bool norm_by_times,
-    const T& pad_value, const OutputLayout& output_layout) {
+    T* dst, const T* src, const T* pad_value, bool is_constant_pad,
+    const size_t* seq_offsets, const size_t& seq_num, const size_t& pad_seq_len,
+    const size_t& step_width, bool norm_by_len, const PadLayout& layout) {
   size_t seq_idx = blockIdx.y;
-  size_t seq_start = seq_offset[seq_idx];
-  size_t seq_len = seq_offset[seq_idx + 1] - seq_start;
-
-  size_t seq_step_idx = blockIdx.x * blockDim.y + threadIdx.y;
-
-  size_t seq_data_offset = (seq_start + seq_step_idx) * seq_width;
-
-  size_t pad_data_offset = 0;
-
-  if (output_layout == kLengthBatchWidth) {
-    pad_data_offset = (seq_step_idx * seq_num + seq_idx) * seq_width;
-  } else {
-    pad_data_offset = (seq_idx * max_seq_len + seq_step_idx) * seq_width;
-  }
-
-  if (seq_step_idx < seq_len) {
-    T scale = norm_by_times ? (1.0f / static_cast<T>(seq_len)) : 1.0f;
-    if (Padding) {
-      /* seq -> pad */
-      for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) {
-        pad_data[pad_data_offset + i] = scale * seq_data[seq_data_offset + i];
-      }
-    } else {
-      /* pad -> seq */
-      for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) {
-        seq_data[seq_data_offset + i] = scale * pad_data[pad_data_offset + i];
-      }
+  size_t seq_len = seq_offsets[seq_idx + 1] - seq_offsets[seq_idx];
+
+  size_t step_idx = blockIdx.x * blockDim.y + threadIdx.y;
+  size_t seq_data_offset = (seq_offsets[seq_idx] + step_idx) * step_width;
+  size_t pad_data_offset = layout == kBatchLengthWidth
+                               ? (seq_idx * pad_seq_len + step_idx) * step_width
+                               : (step_idx * seq_num + seq_idx) * step_width;
+
+  T* dst_data = dst + (Type == kSeqToPad ? pad_data_offset : seq_data_offset);
+  const T* src_data =
+      src + (Type == kSeqToPad ? seq_data_offset : pad_data_offset);
+
+  if (step_idx < seq_len) {
+    float scale = norm_by_len ? (1.0f / static_cast<float>(seq_len)) : 1.0f;
+    for (size_t i = threadIdx.x; i < step_width; i += blockDim.x) {
+      dst_data[i] = scale * src_data[i];
     }
-  } else if (seq_step_idx < max_seq_len) {
-    if (Padding) {
-      /* seq -> pad */
-      for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) {
-        pad_data[pad_data_offset + i] = pad_value;
-      }
+  } else if (step_idx < pad_seq_len && Type == kSeqToPad) {
+    for (size_t i = threadIdx.x; i < step_width; i += blockDim.x) {
+      dst_data[i] = is_constant_pad ? pad_value[0] : pad_value[i];
     }
   }
 }
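Note: the rewritten kernel folds the old layout branches into a single pad-offset expression. Below is a minimal standalone C++ sketch of that arithmetic (not part of the diff); the layout meanings are inferred from the index expressions, with kBatchLengthWidth taken as a [seq_num, pad_seq_len, step_width] pad tensor and kLengthBatchWidth as [pad_seq_len, seq_num, step_width].

// Standalone sketch of the pad-offset arithmetic in SequencePaddingKernel.
// Layout meanings are assumptions inferred from the index expressions above.
#include <cstddef>
#include <cstdio>

enum PadLayout { kBatchLengthWidth, kLengthBatchWidth };

size_t PadOffset(PadLayout layout, size_t seq_idx, size_t step_idx,
                 size_t seq_num, size_t pad_seq_len, size_t step_width) {
  return layout == kBatchLengthWidth
             ? (seq_idx * pad_seq_len + step_idx) * step_width  // [batch, length, width]
             : (step_idx * seq_num + seq_idx) * step_width;     // [length, batch, width]
}

int main() {
  const size_t seq_num = 2, pad_seq_len = 4, step_width = 2;
  // The step (seq 1, step 2) lands at different flat offsets per layout.
  printf("kBatchLengthWidth: %zu\n",
         PadOffset(kBatchLengthWidth, 1, 2, seq_num, pad_seq_len, step_width));  // 12
  printf("kLengthBatchWidth: %zu\n",
         PadOffset(kLengthBatchWidth, 1, 2, seq_num, pad_seq_len, step_width));  // 10
  return 0;
}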
@@ -69,24 +55,26 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
   void operator()(const platform::CUDADeviceContext& context,
                   const framework::LoDTensor& seq_tensor,
                   framework::Tensor* pad_tensor,
-                  T pad_value = static_cast<T>(0), bool norm_by_times = false,
-                  size_t lod_level = 0,
-                  OutputLayout output_layout = kBatchLengthWidth) {
-    CheckLoD(seq_tensor, lod_level);
-
-    auto& lod = seq_tensor.lod();
-    auto& seq_offset = framework::ToAbsOffset(lod)[lod_level];
-
-    auto seq_tensor_dims = seq_tensor.dims();
-    auto pad_tensor_dims = pad_tensor->dims();
-    int64_t max_seq_len = MaximumSequenceLength(seq_offset);
-    int64_t seq_num = seq_offset.size() - 1;
-    int64_t seq_width = seq_tensor.numel() / seq_tensor_dims[0];
+                  const framework::LoDTensor& pad_value, int pad_seq_len = -1,
+                  int lod_level = 0, bool norm_by_times = false,
+                  const PadLayout layout = kBatchLengthWidth) {
+    auto seq_lod = seq_tensor.lod();
+    const auto seq_offsets = framework::ToAbsOffset(seq_lod)[lod_level];
+    const auto& seq_tensor_dims = seq_tensor.dims();
+    const auto& pad_tensor_dims = pad_tensor->dims();
+    if (pad_seq_len == -1) {
+      pad_seq_len = MaximumSequenceLength(seq_offsets);
+    }
+    int step_width = seq_tensor.numel() / seq_tensor_dims[0];
+    int seq_num = seq_offsets.size() - 1;
 
-    CheckDims(seq_tensor_dims, seq_offset.back(), pad_tensor_dims, max_seq_len,
-              seq_num, seq_width, output_layout);
+    CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len,
+              step_width, layout);
+    PADDLE_ENFORCE(pad_value.numel() == 1 || pad_value.numel() == step_width,
+                   "The numel of 'pad_value' can only be 1 or be equal to the "
+                   "'step_width'.");
 
-    if (!norm_by_times && seq_num == 1UL) {
+    if (!norm_by_times && seq_num == 1UL && pad_seq_len == -1) {
       TensorCopy(seq_tensor, context.GetPlace(), context, pad_tensor);
       pad_tensor->Resize(pad_tensor_dims);
       return;
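Note: with pad_value now a LoDTensor and pad_seq_len an explicit argument, the functor resolves the padded length before checking shapes: -1 falls back to the longest sequence in the batch, and pad_value must contain either a single element or one element per step. Below is a small sketch of that fallback, assuming MaximumSequenceLength() returns the largest gap between consecutive absolute offsets; its definition is outside this diff.

// Sketch of the pad_seq_len fallback under the assumption above.
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

int ResolvePadSeqLen(int pad_seq_len, const std::vector<size_t>& seq_offsets) {
  if (pad_seq_len != -1) return pad_seq_len;  // caller-supplied length wins
  size_t max_len = 0;
  for (size_t i = 1; i < seq_offsets.size(); ++i) {
    max_len = std::max(max_len, seq_offsets[i] - seq_offsets[i - 1]);
  }
  return static_cast<int>(max_len);
}

int main() {
  // Offsets {0, 2, 5, 9} describe three sequences of length 2, 3 and 4.
  std::vector<size_t> offsets = {0, 2, 5, 9};
  assert(ResolvePadSeqLen(-1, offsets) == 4);  // default: longest sequence
  assert(ResolvePadSeqLen(6, offsets) == 6);   // explicit length is kept
  return 0;
}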
@@ -98,47 +86,46 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
      * and at least 8 elements for each thread.
      */
     size_t block_dim_x =
-        std::min(((((seq_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize);
+        std::min(((((step_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize);
     size_t block_dim_y = kBlockSize / block_dim_x;
     dim3 threads(block_dim_x, block_dim_y);
 
-    size_t grid_dim_x = (max_seq_len + block_dim_y - 1) / block_dim_y;
+    size_t grid_dim_x = (pad_seq_len + block_dim_y - 1) / block_dim_y;
     size_t grid_dim_y = seq_num;
     dim3 grid(grid_dim_x, grid_dim_y);
 
     const T* seq_data = seq_tensor.data<T>();
     T* pad_data = pad_tensor->data<T>();
+    const T* pad_value_data = pad_value.data<T>();
 
-    SequencePaddingKernel<T, 1><<<grid, threads, 0, context.stream()>>>(
-        pad_data, const_cast<T*>(seq_data),
-        seq_offset.CUDAData(context.GetPlace()), seq_num, max_seq_len,
-        seq_width, norm_by_times, pad_value, output_layout);
+    SequencePaddingKernel<T, kSeqToPad><<<grid, threads, 0, context.stream()>>>(
+        pad_data, seq_data, pad_value_data, pad_value.numel() == 1,
+        seq_offsets.CUDAData(context.GetPlace()), seq_num, pad_seq_len,
+        step_width, norm_by_times, layout);
   }
 };
 
 template <typename T>
 class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
-                  framework::LoDTensor* seq_tensor,
-                  const framework::Tensor& pad_tensor,
-                  bool norm_by_times = false, size_t lod_level = 0,
-                  OutputLayout output_layout = kBatchLengthWidth) {
-    CheckLoD(*seq_tensor, lod_level);
-
-    auto& lod = seq_tensor->lod();
-    auto& seq_offset = framework::ToAbsOffset(lod)[lod_level];
-
-    auto seq_tensor_dims = seq_tensor->dims();
-    auto pad_tensor_dims = pad_tensor.dims();
-    int64_t max_seq_len = MaximumSequenceLength(seq_offset);
-    int64_t seq_num = seq_offset.size() - 1;
-    int64_t seq_width = seq_tensor->numel() / seq_tensor_dims[0];
+                  const framework::LoDTensor& pad_tensor,
+                  framework::LoDTensor* seq_tensor, int pad_seq_len = -1,
+                  int lod_level = 0, bool norm_by_times = false,
+                  const PadLayout layout = kBatchLengthWidth) {
+    auto seq_offsets = framework::ToAbsOffset(seq_tensor->lod())[lod_level];
+    const auto& seq_tensor_dims = seq_tensor->dims();
+    const auto& pad_tensor_dims = pad_tensor.dims();
+    if (pad_seq_len == -1) {
+      pad_seq_len = MaximumSequenceLength(seq_offsets);
+    }
+    int step_width = seq_tensor->numel() / seq_tensor_dims[0];
+    int seq_num = seq_offsets.size() - 1;
 
-    CheckDims(seq_tensor_dims, seq_offset.back(), pad_tensor_dims, max_seq_len,
-              seq_num, seq_width, output_layout);
+    CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len,
+              step_width, layout);
 
-    if (!norm_by_times && seq_num == 1UL) {
+    if (!norm_by_times && seq_num == 1UL && pad_seq_len == -1) {
       TensorCopy(pad_tensor, context.GetPlace(), context, seq_tensor);
       seq_tensor->Resize(seq_tensor_dims);
       return;
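Note: both functors size the launch identically; only the sequence bound changed from max_seq_len to pad_seq_len. Below is a worked example of the geometry used in the hunk above (and repeated for unpadding in the next hunk), with kBlockSize assumed to be 512 since its definition lies outside these hunks.

// Worked example of the launch-shape arithmetic shared by both functors.
#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
  const size_t kBlockSize = 512;  // assumption, not shown in the diff
  size_t step_width = 20, pad_seq_len = 100, seq_num = 8;

  // ceil(step_width / 8), rounded up to a multiple of 32 (one warp), capped
  // at kBlockSize: each thread covers at least 8 elements of a time step.
  size_t block_dim_x =
      std::min(((((step_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize);
  size_t block_dim_y = kBlockSize / block_dim_x;

  // blockIdx.x walks the padded time dimension, blockIdx.y walks sequences.
  size_t grid_dim_x = (pad_seq_len + block_dim_y - 1) / block_dim_y;
  size_t grid_dim_y = seq_num;

  printf("threads = (%zu, %zu), grid = (%zu, %zu)\n", block_dim_x, block_dim_y,
         grid_dim_x, grid_dim_y);
  // step_width = 20 -> ceil(20/8) = 3 -> rounded to 32 threads in x, 16 in y;
  // grid = (ceil(100/16), 8) = (7, 8).
  return 0;
}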
@@ -150,21 +137,21 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
      * and at least 8 elements for each thread.
      */
     size_t block_dim_x =
-        std::min(((((seq_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize);
+        std::min(((((step_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize);
     size_t block_dim_y = kBlockSize / block_dim_x;
     dim3 threads(block_dim_x, block_dim_y);
 
-    size_t grid_dim_x = (max_seq_len + block_dim_y - 1) / block_dim_y;
+    size_t grid_dim_x = (pad_seq_len + block_dim_y - 1) / block_dim_y;
     size_t grid_dim_y = seq_num;
     dim3 grid(grid_dim_x, grid_dim_y);
 
     const T* pad_data = pad_tensor.data<T>();
     T* seq_data = seq_tensor->data<T>();
 
-    SequencePaddingKernel<T, 0><<<grid, threads, 0, context.stream()>>>(
-        const_cast<T*>(pad_data), seq_data,
-        seq_offset.CUDAData(context.GetPlace()), seq_num, max_seq_len,
-        seq_width, norm_by_times, static_cast<T>(0), output_layout);
+    SequencePaddingKernel<T, kPadToSeq><<<grid, threads, 0, context.stream()>>>(
+        seq_data, pad_data, nullptr, false,
+        seq_offsets.CUDAData(context.GetPlace()), seq_num, pad_seq_len,
+        step_width, norm_by_times, layout);
   }
 };
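Note: a single kernel body now serves both directions; the CopyType template parameter only decides which offset applies to dst and which to src, which is why the unpad launch can pass nullptr and false for the pad value (out-of-range steps are simply skipped when copying back). Below is a self-contained illustration of that dispatch with arbitrary example offsets, not this commit's code.

// Standalone illustration of the compile-time direction switch used by
// SequencePaddingKernel; offsets and data here are arbitrary examples.
#include <cassert>
#include <cstddef>
#include <vector>

enum CopyType { kSeqToPad, kPadToSeq };

template <typename T, CopyType Type>
void CopyStep(T* dst, const T* src, size_t seq_offset, size_t pad_offset,
              size_t step_width) {
  // kSeqToPad writes into the padded buffer, kPadToSeq writes back into the
  // sequence buffer; the loop body is identical either way.
  T* dst_data = dst + (Type == kSeqToPad ? pad_offset : seq_offset);
  const T* src_data = src + (Type == kSeqToPad ? seq_offset : pad_offset);
  for (size_t i = 0; i < step_width; ++i) dst_data[i] = src_data[i];
}

int main() {
  std::vector<float> seq = {1, 2, 3, 4};  // two steps, step_width = 2
  std::vector<float> pad(6, 0);           // room for three padded steps
  // Copy the step at flat offset 2 of seq into flat offset 4 of pad ...
  CopyStep<float, kSeqToPad>(pad.data(), seq.data(), 2, 4, 2);
  assert(pad[4] == 3 && pad[5] == 4);
  // ... and recover it again with the roles swapped by the template argument.
  std::vector<float> back(4, 0);
  CopyStep<float, kPadToSeq>(back.data(), pad.data(), 2, 4, 2);
  assert(back[2] == 3 && back[3] == 4);
  return 0;
}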