@@ -22,8 +22,8 @@ namespace math {
 template <typename T, CopyType Type>
 __global__ void SequencePaddingKernel(
     T* dst, const T* src, const T* pad_value, bool is_constant_pad,
-    const size_t* seq_offsets, const size_t& seq_num, const size_t& pad_seq_len,
-    const size_t& step_width, bool norm_by_len, const PadLayout& layout) {
+    const size_t* seq_offsets, const size_t seq_num, const size_t pad_seq_len,
+    const size_t step_width, bool norm_by_len, const PadLayout layout) {
   size_t seq_idx = blockIdx.y;
   size_t seq_len = seq_offsets[seq_idx + 1] - seq_offsets[seq_idx];
 
@@ -43,7 +43,7 @@ __global__ void SequencePaddingKernel(
       dst_data[i] = scale * src_data[i];
     }
   } else if (step_idx < pad_seq_len && Type == kSeqToPad) {
-    for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) {
+    for (size_t i = threadIdx.x; i < step_width; i += blockDim.x) {
       dst_data[i] = is_constant_pad ? pad_value[0] : pad_value[i];
     }
   }
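
Note on the two kernel-side hunks above: the scalar parameters of SequencePaddingKernel change from const references to plain values, and the padding loop is bounded by step_width instead of the undeclared seq_width. A plausible reading (not stated in the diff) is that reference arguments to a __global__ kernel are passed as host addresses that device code cannot safely dereference, whereas plain values are copied into the kernel's parameter space. The minimal standalone CUDA sketch below, unrelated to Paddle, only illustrates that by-value convention; ScaleKernel and its launch configuration are made up for the example.

#include <cstdio>
#include <cuda_runtime.h>

// Scalars such as n and scale are passed by value: the launch copies them
// into device-visible kernel parameter space, so no host memory is read
// from the device.
__global__ void ScaleKernel(float* data, const size_t n, const float scale) {
  size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    data[i] *= scale;
  }
}

int main() {
  const size_t n = 8;
  float host[n] = {1, 2, 3, 4, 5, 6, 7, 8};
  float* dev = nullptr;
  cudaMalloc(&dev, n * sizeof(float));
  cudaMemcpy(dev, host, n * sizeof(float), cudaMemcpyHostToDevice);
  ScaleKernel<<<1, 32>>>(dev, n, 0.5f);  // n and scale go by value
  cudaMemcpy(host, dev, n * sizeof(float), cudaMemcpyDeviceToHost);
  cudaFree(dev);
  printf("first element after scaling: %f\n", host[0]);
  return 0;
}
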
@@ -54,33 +54,34 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
                   const framework::LoDTensor& seq_tensor,
-                  framework::Tensor* pad_tensor,
+                  framework::LoDTensor* pad_tensor,
                   const framework::LoDTensor& pad_value, int pad_seq_len = -1,
                   int lod_level = 0, bool norm_by_times = false,
                   const PadLayout layout = kBatchLengthWidth) {
     auto seq_lod = seq_tensor.lod();
     const auto seq_offsets = framework::ToAbsOffset(seq_lod)[lod_level];
     const auto& seq_tensor_dims = seq_tensor.dims();
     const auto& pad_tensor_dims = pad_tensor->dims();
+    int max_seq_len = MaximumSequenceLength(seq_offsets);
     if (pad_seq_len == -1) {
-      pad_seq_len = MaximumSequenceLength(seq_offsets);
+      pad_seq_len = max_seq_len;
     }
     int step_width = seq_tensor.numel() / seq_tensor_dims[0];
-    int seq_num = seq_offset.size() - 1;
+    int seq_num = seq_offsets.size() - 1;
 
     CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len,
               step_width, layout);
     PADDLE_ENFORCE(pad_value.numel() == 1 || pad_value.numel() == step_width,
                    "The numel of 'pad_value' can only be 1 or be equal to the "
                    "'step_width'.");
 
-    if (!norm_by_times && seq_num == 1UL && pad_seq_len == -1) {
+    if (!norm_by_times && seq_num == 1UL && pad_seq_len == max_seq_len) {
       TensorCopy(seq_tensor, context.GetPlace(), context, pad_tensor);
       pad_tensor->Resize(pad_tensor_dims);
       return;
     }
 
-    const int64_t kBlockSize = 512;
+    const int kBlockSize = 512;
 
     /* At least use 32 threads to copy sequence_width elements,
      * and at least 8 elements for each thread.
@@ -100,8 +101,16 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
 
     SequencePaddingKernel<T, kSeqToPad><<<grid, threads, 0, context.stream()>>>(
         pad_data, seq_data, pad_value_data, pad_value.numel() == 1,
-        seq_offset.CUDAData(context.GetPlace()), seq_num, pad_seq_len,
+        seq_offsets.CUDAData(context.GetPlace()), seq_num, pad_seq_len,
         step_width, norm_by_times, layout);
+
+    if (layout == kBatchLengthWidth) {
+      framework::LoD pad_lod(seq_lod.begin() + lod_level, seq_lod.end());
+      for (size_t i = 0; i < pad_lod[0].size(); ++i) {
+        pad_lod[0][i] = i * pad_seq_len;
+      }
+      pad_tensor->set_lod(pad_lod);
+    }
   }
 };
 
@@ -116,22 +125,23 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
     auto seq_offsets = framework::ToAbsOffset(seq_tensor->lod())[lod_level];
     const auto& seq_tensor_dims = seq_tensor->dims();
     const auto& pad_tensor_dims = pad_tensor.dims();
+    int max_seq_len = MaximumSequenceLength(seq_offsets);
     if (pad_seq_len == -1) {
-      pad_seq_len = MaximumSequenceLength(seq_offsets);
+      pad_seq_len = max_seq_len;
     }
     int step_width = seq_tensor->numel() / seq_tensor_dims[0];
-    int seq_num = seq_offset.size() - 1;
+    int seq_num = seq_offsets.size() - 1;
 
     CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len,
               step_width, layout);
 
-    if (!norm_by_times && seq_num == 1UL && pad_seq_len == -1) {
+    if (!norm_by_times && seq_num == 1UL && pad_seq_len == max_seq_len) {
       TensorCopy(pad_tensor, context.GetPlace(), context, seq_tensor);
       seq_tensor->Resize(seq_tensor_dims);
       return;
     }
 
-    const int64_t kBlockSize = 512;
+    const int kBlockSize = 512;
 
     /* At least use 32 threads to copy sequence_width elements,
      * and at least 8 elements for each thread.
@@ -150,7 +160,7 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
 
     SequencePaddingKernel<T, kPadToSeq><<<grid, threads, 0, context.stream()>>>(
         seq_data, pad_data, nullptr, false,
-        seq_offset.CUDAData(context.GetPlace()), seq_num, pad_seq_len,
+        seq_offsets.CUDAData(context.GetPlace()), seq_num, pad_seq_len,
         step_width, norm_by_times, layout);
   }
 };
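
Note on the new kBatchLengthWidth branch in PaddingLoDTensorFunctor: it attaches a LoD to the padded output in which every sequence occupies exactly pad_seq_len steps. The standalone C++ sketch below (assumed semantics, not Paddle code; the sample offsets are made up) mirrors the added pad_lod[0][i] = i * pad_seq_len computation on a plain std::vector.

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // Absolute level-0 offsets of an unpadded LoDTensor: three sequences of
  // lengths 2, 4, and 3 (hypothetical sample data).
  std::vector<size_t> seq_offsets = {0, 2, 6, 9};
  const size_t pad_seq_len = 4;  // length of the longest sequence

  // Mirrors the added loop: pad_lod[0][i] = i * pad_seq_len;
  std::vector<size_t> pad_offsets(seq_offsets.size());
  for (size_t i = 0; i < pad_offsets.size(); ++i) {
    pad_offsets[i] = i * pad_seq_len;
  }

  // After padding, sequence i starts at row i * pad_seq_len.
  for (size_t offset : pad_offsets) {
    std::cout << offset << " ";  // prints: 0 4 8 12
  }
  std::cout << std::endl;
  return 0;
}

With offsets of this shape the padded output carries sequence structure of its own, which is presumably why pad_tensor was also promoted from framework::Tensor to framework::LoDTensor in this patch.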