@@ -21,74 +21,74 @@ namespace math {
 template <typename T, bool Padding>
 __global__ void SequencePaddingKernel(
-    T* padding_data, T* seq_data, const size_t* abs_offset,
-    const size_t& seq_num, const size_t& max_seq_len, const size_t& seq_width,
-    const PaddingLayout& padding_layout, bool norm_by_times = false,
-    const T& padding_value = 0) {
-  size_t padding_idx = blockIdx.y;
-  size_t seq_start = abs_offset[padding_idx];
-  size_t seq_len = abs_offset[padding_idx + 1] - seq_start;
+    T* pad_data, T* seq_data, const size_t* seq_offset, const size_t& seq_num,
+    const size_t& max_seq_len, const size_t& seq_width, bool norm_by_times,
+    const T& pad_value, const OutputLayout& output_layout) {
+  size_t seq_idx = blockIdx.y;
+  size_t seq_start = seq_offset[seq_idx];
+  size_t seq_len = seq_offset[seq_idx + 1] - seq_start;

-  size_t seq_idx = blockIdx.x * blockDim.y + threadIdx.y;
+  size_t seq_step_idx = blockIdx.x * blockDim.y + threadIdx.y;

-  size_t seq_offset = (seq_start + seq_idx) * seq_width;
+  size_t seq_data_offset = (seq_start + seq_step_idx) * seq_width;

-  size_t padding_offset = 0;
+  size_t pad_data_offset = 0;

-  if (padding_layout == LENGTH_BATCH_WIDTH) {
-    padding_offset = (seq_idx * seq_num + padding_idx) * seq_width;
+  if (output_layout == kLengthBatchWidth) {
+    pad_data_offset = (seq_step_idx * seq_num + seq_idx) * seq_width;
   } else {
-    padding_offset = (padding_idx * max_seq_len + seq_idx) * seq_width;
+    pad_data_offset = (seq_idx * max_seq_len + seq_step_idx) * seq_width;
   }

-  if (seq_idx < seq_len) {
+  if (seq_step_idx < seq_len) {
     T scale = norm_by_times ? (1.0f / static_cast<T>(seq_len)) : 1.0f;
     if (Padding) {
-      /* sequence -> padding */
+      /* seq -> pad */
       for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) {
-        padding_data[padding_offset + i] = scale * seq_data[seq_offset + i];
+        pad_data[pad_data_offset + i] = scale * seq_data[seq_data_offset + i];
       }
     } else {
-      /* padding -> sequence */
+      /* pad -> seq */
       for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) {
-        seq_data[seq_offset + i] = scale * padding_data[padding_offset + i];
+        seq_data[seq_data_offset + i] = scale * pad_data[pad_data_offset + i];
       }
     }
-  } else if (seq_idx < max_seq_len) {
+  } else if (seq_step_idx < max_seq_len) {
     if (Padding) {
-      /* sequence -> padding */
+      /* seq -> pad */
       for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) {
-        padding_data[padding_offset + i] = padding_value;
+        pad_data[pad_data_offset + i] = pad_value;
       }
     }
   }
 }

-template <typename T, PaddingLayout padding_layout>
-class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T, padding_layout> {
+template <typename T>
+class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
                   const framework::LoDTensor& seq_tensor,
-                  framework::Tensor* padding_tensor,
-                  T padding_value = static_cast<T>(0),
-                  bool norm_by_times = false, size_t lod_level = 0) {
-    ValidateLoD(seq_tensor, lod_level);
+                  framework::Tensor* pad_tensor,
+                  T pad_value = static_cast<T>(0), bool norm_by_times = false,
+                  size_t lod_level = 0,
+                  OutputLayout output_layout = kBatchLengthWidth) {
+    CheckLoD(seq_tensor, lod_level);

     auto& lod = seq_tensor.lod();
-    auto& abs_offset = framework::ToAbsOffset(lod)[lod_level];
+    auto& seq_offset = framework::ToAbsOffset(lod)[lod_level];

-    auto seq_dims = seq_tensor.dims();
-    auto padding_dims = padding_tensor->dims();
-    int64_t max_seq_len = MaximumSequenceLength(lod, lod_level);
-    const int64_t seq_num = abs_offset.size() - 1;
-    const int64_t seq_width = seq_tensor.numel() / seq_dims[0];
+    auto seq_tensor_dims = seq_tensor.dims();
+    auto pad_tensor_dims = pad_tensor->dims();
+    int64_t max_seq_len = MaximumSequenceLength(seq_offset);
+    int64_t seq_num = seq_offset.size() - 1;
+    int64_t seq_width = seq_tensor.numel() / seq_tensor_dims[0];

-    ValidateShape(seq_dims, abs_offset.back(), padding_dims, max_seq_len,
-                  seq_num, seq_width, padding_layout);
+    CheckDims(seq_tensor_dims, seq_offset.back(), pad_tensor_dims, max_seq_len,
+              seq_num, seq_width, output_layout);

     if (!norm_by_times && seq_num == 1UL) {
-      TensorCopy(seq_tensor, context.GetPlace(), context, padding_tensor);
-      padding_tensor->Resize(padding_dims);
+      TensorCopy(seq_tensor, context.GetPlace(), context, pad_tensor);
+      pad_tensor->Resize(pad_tensor_dims);
       return;
     }
@@ -107,37 +107,40 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T, padding_layout> {
     dim3 grid(grid_dim_x, grid_dim_y);

     const T* seq_data = seq_tensor.data<T>();
-    T* padding_data = padding_tensor->data<T>();
+    T* pad_data = pad_tensor->data<T>();

     SequencePaddingKernel<T, 1><<<grid, threads, 0, context.stream()>>>(
-        padding_data, const_cast<T*>(seq_data),
-        abs_offset.CUDAData(context.GetPlace()), seq_num, max_seq_len,
-        seq_width, padding_layout, norm_by_times, padding_value);
+        pad_data, const_cast<T*>(seq_data),
+        seq_offset.CUDAData(context.GetPlace()), seq_num, max_seq_len,
+        seq_width, norm_by_times, pad_value, output_layout);
   }
 };

-template <typename T, PaddingLayout padding_layout>
-class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T,
-                                padding_layout> {
+template <typename T>
+class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
                   framework::LoDTensor* seq_tensor,
-                  const framework::Tensor& padding_tensor,
-                  bool norm_by_times = false, size_t lod_level = 0) {
-    ValidateLoD(*seq_tensor, lod_level);
+                  const framework::Tensor& pad_tensor,
+                  bool norm_by_times = false, size_t lod_level = 0,
+                  OutputLayout output_layout = kBatchLengthWidth) {
+    CheckLoD(*seq_tensor, lod_level);

     auto& lod = seq_tensor->lod();
-    auto& abs_offset = framework::ToAbsOffset(lod)[lod_level];
+    auto& seq_offset = framework::ToAbsOffset(lod)[lod_level];

-    auto seq_dims = seq_tensor->dims();
-    auto padding_dims = padding_tensor.dims();
-    int64_t max_seq_len = MaximumSequenceLength(lod, lod_level);
-    int64_t seq_num = abs_offset.size() - 1;
-    int64_t seq_width = seq_tensor->numel() / seq_dims[0];
+    auto seq_tensor_dims = seq_tensor->dims();
+    auto pad_tensor_dims = pad_tensor.dims();
+    int64_t max_seq_len = MaximumSequenceLength(seq_offset);
+    int64_t seq_num = seq_offset.size() - 1;
+    int64_t seq_width = seq_tensor->numel() / seq_tensor_dims[0];
+
+    CheckDims(seq_tensor_dims, seq_offset.back(), pad_tensor_dims, max_seq_len,
+              seq_num, seq_width, output_layout);

     if (!norm_by_times && seq_num == 1UL) {
-      TensorCopy(padding_tensor, context.GetPlace(), context, seq_tensor);
-      seq_tensor->Resize(seq_dims);
+      TensorCopy(pad_tensor, context.GetPlace(), context, seq_tensor);
+      seq_tensor->Resize(seq_tensor_dims);
       return;
     }
@@ -155,20 +158,25 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T,
     size_t grid_dim_y = seq_num;
     dim3 grid(grid_dim_x, grid_dim_y);

-    const T* padding_data = padding_tensor.data<T>();
+    const T* pad_data = pad_tensor.data<T>();
     T* seq_data = seq_tensor->data<T>();

-    SequencePaddingKernel<T, 1><<<grid, threads, 0, context.stream()>>>(
-        const_cast<T*>(padding_data), seq_data,
-        abs_offset.CUDAData(context.GetPlace()), seq_num, max_seq_len,
-        seq_width, padding_layout, norm_by_times);
+    SequencePaddingKernel<T, 0><<<grid, threads, 0, context.stream()>>>(
+        const_cast<T*>(pad_data), seq_data,
+        seq_offset.CUDAData(context.GetPlace()), seq_num, max_seq_len,
+        seq_width, norm_by_times, static_cast<T>(0), output_layout);
   }
 };

-template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, float,
-                                       LENGTH_BATCH_WIDTH>;
-template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, float,
-                                         LENGTH_BATCH_WIDTH>;
+template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, int>;
+template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, int64_t>;
+template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, float>;
+template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, double>;
+
+template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, int>;
+template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, int64_t>;
+template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, float>;
+template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, double>;

 }  // namespace math
 }  // namespace operators
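
Note for reviewers: the heart of this change is the offset arithmetic that maps a (sequence, time step, feature) triple into the padded tensor under the two layouts introduced here. The snippet below is a minimal host-side sketch of that arithmetic, not part of the patch; the OutputLayout enumerators are copied from this diff (the enum itself lives in the accompanying header, not shown), and PadDataOffset is a hypothetical helper for illustration only.

// Host-side illustration of the kernel's pad_data_offset computation.
// kLengthBatchWidth assumes a pad tensor shaped [max_seq_len, seq_num, seq_width];
// kBatchLengthWidth assumes [seq_num, max_seq_len, seq_width].
#include <cstddef>
#include <cstdio>

enum OutputLayout { kBatchLengthWidth, kLengthBatchWidth };

size_t PadDataOffset(size_t seq_idx, size_t seq_step_idx, size_t seq_num,
                     size_t max_seq_len, size_t seq_width,
                     OutputLayout output_layout) {
  if (output_layout == kLengthBatchWidth) {
    return (seq_step_idx * seq_num + seq_idx) * seq_width;
  }
  return (seq_idx * max_seq_len + seq_step_idx) * seq_width;
}

int main() {
  // First element of time step 2 in sequence 1, for a batch of 3 sequences
  // padded to length 4 with feature width 5.
  std::printf("%zu\n", PadDataOffset(1, 2, 3, 4, 5, kLengthBatchWidth));  // (2*3+1)*5 = 35
  std::printf("%zu\n", PadDataOffset(1, 2, 3, 4, 5, kBatchLengthWidth));  // (1*4+2)*5 = 30
  return 0;
}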