Commit 9db40b8

Complete sequence_padding GPU kernel

1 parent c3de69a commit 9db40b8

7 files changed: +113 -104 lines

paddle/fluid/operators/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -282,6 +282,7 @@ op_library(unsqueeze_op DEPS reshape_op)
 op_library(squeeze_op DEPS reshape_op)
 op_library(extract_rows_op DEPS memory)
 op_library(flatten_op DEPS reshape_op)
+op_library(sequence_pad_op DEPS sequence_padding)
 
 if (WITH_GPU)
     op_library(conv_op DEPS vol2col depthwise_conv im2col)

paddle/fluid/operators/math/sequence_padding.cc

Lines changed: 13 additions & 13 deletions

@@ -18,8 +18,6 @@ namespace paddle {
 namespace operators {
 namespace math {
 
-enum CopyType { kSeqToPad, kPadToSeq };
-
 template <typename T>
 void CopyValidData(framework::Tensor* dst_tensor,
                    const framework::Tensor* src_tensor,
@@ -67,7 +65,7 @@ class PaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
   void operator()(const platform::CPUDeviceContext& context,
                   const framework::LoDTensor& seq_tensor,
                   framework::LoDTensor* pad_tensor,
-                  std::vector<T> pad_value = {0}, int pad_seq_len = -1,
+                  const framework::LoDTensor& pad_value, int pad_seq_len = -1,
                   int lod_level = 0, bool norm_by_times = false,
                   const PadLayout layout = kBatchLengthWidth) {
     auto seq_lod = seq_tensor.lod();
@@ -81,19 +79,21 @@ class PaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
 
     CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len,
               step_width, layout);
-    PADDLE_ENFORCE(pad_value.size() == 1 ||
-                       static_cast<int>(pad_value.size()) == step_width,
-                   "The size of 'pad_value' can only be 1 or be equal to the "
+    PADDLE_ENFORCE(pad_value.numel() == 1 || pad_value.numel() == step_width,
+                   "The numel of 'pad_value' can only be 1 or be equal to the "
                    "'step_width'.");
 
-    if (pad_value.size() == 1) {
-      pad_value = std::vector<T>(step_width, pad_value[0]);
-    }
-
     // fill padding value
     T* pad_data = pad_tensor->data<T>();
-    for (int i = 0; i < pad_tensor->numel(); i += step_width) {
-      memcpy(pad_data + i, pad_value.data(), step_width * sizeof(T));
+    const T* pad_value_data = pad_value.data<T>();
+    if (pad_value.numel() == 1) {
+      for (int i = 0; i < pad_tensor->numel(); ++i) {
+        pad_data[i] = *pad_value_data;
+      }
+    } else {
+      for (int i = 0; i < pad_tensor->numel(); i += step_width) {
+        memcpy(pad_data + i, pad_value_data, step_width * sizeof(T));
+      }
     }
 
     CopyValidData<T>(pad_tensor, &seq_tensor, seq_offsets, pad_seq_len,
@@ -117,7 +117,7 @@ class UnpaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
                   const framework::LoDTensor& pad_tensor,
                   framework::LoDTensor* seq_tensor, int pad_seq_len = -1,
                   int lod_level = 0, bool norm_by_times = false,
-                  const PadLayout& layout = kBatchLengthWidth) {
+                  const PadLayout layout = kBatchLengthWidth) {
     auto seq_offsets = framework::ToAbsOffset(seq_tensor->lod())[lod_level];
     const auto& seq_tensor_dims = seq_tensor->dims();
     const auto& pad_tensor_dims = pad_tensor.dims();
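
Note: the rewritten CPU fill loop above now covers both a scalar pad_value (numel == 1, broadcast to every element) and a per-column pad_value (numel == step_width, copied once per step). A minimal standalone sketch of that logic, with a plain std::vector standing in for the LoDTensor (names here are illustrative, not Paddle APIs):

#include <cstring>
#include <vector>

// Fill `total` elements of pad_data; pad_value holds either one broadcast
// value or one value per column of a step (step_width values).
template <typename T>
void FillPadValue(T* pad_data, int total, int step_width,
                  const std::vector<T>& pad_value) {
  if (pad_value.size() == 1) {
    for (int i = 0; i < total; ++i) pad_data[i] = pad_value[0];
  } else {
    for (int i = 0; i < total; i += step_width) {
      std::memcpy(pad_data + i, pad_value.data(), step_width * sizeof(T));
    }
  }
}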

paddle/fluid/operators/math/sequence_padding.cu

Lines changed: 69 additions & 82 deletions

@@ -19,46 +19,32 @@ namespace paddle {
 namespace operators {
 namespace math {
 
-template <typename T, bool Padding>
+template <typename T, CopyType Type>
 __global__ void SequencePaddingKernel(
-    T* pad_data, T* seq_data, const size_t* seq_offset, const size_t& seq_num,
-    const size_t& max_seq_len, const size_t& seq_width, bool norm_by_times,
-    const T& pad_value, const OutputLayout& output_layout) {
+    T* dst, const T* src, const T* pad_value, bool is_constant_pad,
+    const size_t* seq_offsets, const size_t& seq_num, const size_t& pad_seq_len,
+    const size_t& step_width, bool norm_by_len, const PadLayout& layout) {
   size_t seq_idx = blockIdx.y;
-  size_t seq_start = seq_offset[seq_idx];
-  size_t seq_len = seq_offset[seq_idx + 1] - seq_start;
-
-  size_t seq_step_idx = blockIdx.x * blockDim.y + threadIdx.y;
-
-  size_t seq_data_offset = (seq_start + seq_step_idx) * seq_width;
-
-  size_t pad_data_offset = 0;
-
-  if (output_layout == kLengthBatchWidth) {
-    pad_data_offset = (seq_step_idx * seq_num + seq_idx) * seq_width;
-  } else {
-    pad_data_offset = (seq_idx * max_seq_len + seq_step_idx) * seq_width;
-  }
-
-  if (seq_step_idx < seq_len) {
-    T scale = norm_by_times ? (1.0f / static_cast<T>(seq_len)) : 1.0f;
-    if (Padding) {
-      /* seq -> pad */
-      for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) {
-        pad_data[pad_data_offset + i] = scale * seq_data[seq_data_offset + i];
-      }
-    } else {
-      /* pad -> seq */
-      for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) {
-        seq_data[seq_data_offset + i] = scale * pad_data[pad_data_offset + i];
-      }
+  size_t seq_len = seq_offsets[seq_idx + 1] - seq_offsets[seq_idx];
+
+  size_t step_idx = blockIdx.x * blockDim.y + threadIdx.y;
+  size_t seq_data_offset = (seq_offsets[seq_idx] + step_idx) * step_width;
+  size_t pad_data_offset = layout == kBatchLengthWidth
+                               ? (seq_idx * pad_seq_len + step_idx) * step_width
+                               : (step_idx * seq_num + seq_idx) * step_width;
+
+  T* dst_data = dst + (Type == kSeqToPad ? pad_data_offset : seq_data_offset);
+  const T* src_data =
+      src + (Type == kSeqToPad ? seq_data_offset : pad_data_offset);
+
+  if (step_idx < seq_len) {
+    float scale = norm_by_len ? (1.0f / static_cast<float>(seq_len)) : 1.0f;
+    for (size_t i = threadIdx.x; i < step_width; i += blockDim.x) {
+      dst_data[i] = scale * src_data[i];
     }
-  } else if (seq_step_idx < max_seq_len) {
-    if (Padding) {
-      /* seq -> pad */
-      for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) {
-        pad_data[pad_data_offset + i] = pad_value;
-      }
+  } else if (step_idx < pad_seq_len && Type == kSeqToPad) {
+    for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) {
+      dst_data[i] = is_constant_pad ? pad_value[0] : pad_value[i];
     }
   }
 }
@@ -69,24 +55,26 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
   void operator()(const platform::CUDADeviceContext& context,
                   const framework::LoDTensor& seq_tensor,
                   framework::Tensor* pad_tensor,
-                  T pad_value = static_cast<T>(0), bool norm_by_times = false,
-                  size_t lod_level = 0,
-                  OutputLayout output_layout = kBatchLengthWidth) {
-    CheckLoD(seq_tensor, lod_level);
-
-    auto& lod = seq_tensor.lod();
-    auto& seq_offset = framework::ToAbsOffset(lod)[lod_level];
-
-    auto seq_tensor_dims = seq_tensor.dims();
-    auto pad_tensor_dims = pad_tensor->dims();
-    int64_t max_seq_len = MaximumSequenceLength(seq_offset);
-    int64_t seq_num = seq_offset.size() - 1;
-    int64_t seq_width = seq_tensor.numel() / seq_tensor_dims[0];
+                  const framework::LoDTensor& pad_value, int pad_seq_len = -1,
+                  int lod_level = 0, bool norm_by_times = false,
+                  const PadLayout layout = kBatchLengthWidth) {
+    auto seq_lod = seq_tensor.lod();
+    const auto seq_offsets = framework::ToAbsOffset(seq_lod)[lod_level];
+    const auto& seq_tensor_dims = seq_tensor.dims();
+    const auto& pad_tensor_dims = pad_tensor->dims();
+    if (pad_seq_len == -1) {
+      pad_seq_len = MaximumSequenceLength(seq_offsets);
+    }
+    int step_width = seq_tensor.numel() / seq_tensor_dims[0];
+    int seq_num = seq_offset.size() - 1;
 
-    CheckDims(seq_tensor_dims, seq_offset.back(), pad_tensor_dims, max_seq_len,
-              seq_num, seq_width, output_layout);
+    CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len,
+              step_width, layout);
+    PADDLE_ENFORCE(pad_value.numel() == 1 || pad_value.numel() == step_width,
+                   "The numel of 'pad_value' can only be 1 or be equal to the "
+                   "'step_width'.");
 
-    if (!norm_by_times && seq_num == 1UL) {
+    if (!norm_by_times && seq_num == 1UL && pad_seq_len == -1) {
       TensorCopy(seq_tensor, context.GetPlace(), context, pad_tensor);
       pad_tensor->Resize(pad_tensor_dims);
       return;
@@ -98,47 +86,46 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
     * and at least 8 elements for each thread.
     */
     size_t block_dim_x =
-        std::min(((((seq_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize);
+        std::min(((((step_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize);
     size_t block_dim_y = kBlockSize / block_dim_x;
     dim3 threads(block_dim_x, block_dim_y);
 
-    size_t grid_dim_x = (max_seq_len + block_dim_y - 1) / block_dim_y;
+    size_t grid_dim_x = (pad_seq_len + block_dim_y - 1) / block_dim_y;
     size_t grid_dim_y = seq_num;
     dim3 grid(grid_dim_x, grid_dim_y);
 
    const T* seq_data = seq_tensor.data<T>();
    T* pad_data = pad_tensor->data<T>();
+    const T* pad_value_data = pad_value.data<T>();
 
-    SequencePaddingKernel<T, 1><<<grid, threads, 0, context.stream()>>>(
-        pad_data, const_cast<T*>(seq_data),
-        seq_offset.CUDAData(context.GetPlace()), seq_num, max_seq_len,
-        seq_width, norm_by_times, pad_value, output_layout);
+    SequencePaddingKernel<T, kSeqToPad><<<grid, threads, 0, context.stream()>>>(
+        pad_data, seq_data, pad_value_data, pad_value.numel() == 1,
+        seq_offset.CUDAData(context.GetPlace()), seq_num, pad_seq_len,
+        step_width, norm_by_times, layout);
   }
 };
 
 template <typename T>
 class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
-                  framework::LoDTensor* seq_tensor,
-                  const framework::Tensor& pad_tensor,
-                  bool norm_by_times = false, size_t lod_level = 0,
-                  OutputLayout output_layout = kBatchLengthWidth) {
-    CheckLoD(*seq_tensor, lod_level);
-
-    auto& lod = seq_tensor->lod();
-    auto& seq_offset = framework::ToAbsOffset(lod)[lod_level];
-
-    auto seq_tensor_dims = seq_tensor->dims();
-    auto pad_tensor_dims = pad_tensor.dims();
-    int64_t max_seq_len = MaximumSequenceLength(seq_offset);
-    int64_t seq_num = seq_offset.size() - 1;
-    int64_t seq_width = seq_tensor->numel() / seq_tensor_dims[0];
+                  const framework::LoDTensor& pad_tensor,
+                  framework::LoDTensor* seq_tensor, int pad_seq_len = -1,
+                  int lod_level = 0, bool norm_by_times = false,
+                  const PadLayout layout = kBatchLengthWidth) {
+    auto seq_offsets = framework::ToAbsOffset(seq_tensor->lod())[lod_level];
+    const auto& seq_tensor_dims = seq_tensor->dims();
+    const auto& pad_tensor_dims = pad_tensor.dims();
+    if (pad_seq_len == -1) {
+      pad_seq_len = MaximumSequenceLength(seq_offsets);
+    }
+    int step_width = seq_tensor->numel() / seq_tensor_dims[0];
+    int seq_num = seq_offset.size() - 1;
 
-    CheckDims(seq_tensor_dims, seq_offset.back(), pad_tensor_dims, max_seq_len,
-              seq_num, seq_width, output_layout);
+    CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len,
+              step_width, layout);
 
-    if (!norm_by_times && seq_num == 1UL) {
+    if (!norm_by_times && seq_num == 1UL && pad_seq_len == -1) {
       TensorCopy(pad_tensor, context.GetPlace(), context, seq_tensor);
       seq_tensor->Resize(seq_tensor_dims);
       return;
@@ -150,21 +137,21 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
     * and at least 8 elements for each thread.
     */
     size_t block_dim_x =
-        std::min(((((seq_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize);
+        std::min(((((step_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize);
     size_t block_dim_y = kBlockSize / block_dim_x;
     dim3 threads(block_dim_x, block_dim_y);
 
-    size_t grid_dim_x = (max_seq_len + block_dim_y - 1) / block_dim_y;
+    size_t grid_dim_x = (pad_seq_len + block_dim_y - 1) / block_dim_y;
     size_t grid_dim_y = seq_num;
     dim3 grid(grid_dim_x, grid_dim_y);
 
     const T* pad_data = pad_tensor.data<T>();
     T* seq_data = seq_tensor->data<T>();
 
-    SequencePaddingKernel<T, 0><<<grid, threads, 0, context.stream()>>>(
-        const_cast<T*>(pad_data), seq_data,
-        seq_offset.CUDAData(context.GetPlace()), seq_num, max_seq_len,
-        seq_width, norm_by_times, static_cast<T>(0), output_layout);
+    SequencePaddingKernel<T, kPadToSeq><<<grid, threads, 0, context.stream()>>>(
+        seq_data, pad_data, nullptr, false,
+        seq_offset.CUDAData(context.GetPlace()), seq_num, pad_seq_len,
+        step_width, norm_by_times, layout);
   }
 };
 
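
Note: the launch-shape arithmetic kept by this change is easy to misread. block_dim_x is ceil(step_width / 8) rounded up to the next multiple of 32 (one warp) and capped at kBlockSize, matching the "at least 8 elements for each thread" comment; block_dim_y gets the remaining threads, and grid_dim_x tiles pad_seq_len by it. A host-only demo of the computation (the kBlockSize value of 512 is an assumption for this demo, not taken from the diff):

#include <algorithm>
#include <cstdio>

int main() {
  const size_t kBlockSize = 512;   // assumed cap for this demo
  const size_t pad_seq_len = 100;  // example padded sequence length
  for (size_t step_width : {1, 17, 128, 4096}) {
    // ceil(step_width / 8), rounded up to a multiple of 32, capped.
    size_t block_dim_x =
        std::min(((((step_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize);
    size_t block_dim_y = kBlockSize / block_dim_x;
    size_t grid_dim_x = (pad_seq_len + block_dim_y - 1) / block_dim_y;
    std::printf("step_width=%4zu -> block=(%zu,%zu), grid_x=%zu\n",
                step_width, block_dim_x, block_dim_y, grid_dim_x);
  }
  return 0;
}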

paddle/fluid/operators/math/sequence_padding.h

Lines changed: 4 additions & 2 deletions

@@ -25,6 +25,8 @@ namespace math {
 
 enum PadLayout { kBatchLengthWidth = 0, kLengthBatchWidth };
 
+enum CopyType { kSeqToPad, kPadToSeq };
+
 inline static size_t MaximumSequenceLength(
     const framework::Vector<size_t>& seq_offset) {
   size_t seq_num = seq_offset.size() - 1;
@@ -82,7 +84,7 @@ class PaddingLoDTensorFunctor {
   void operator()(const platform::CPUDeviceContext& context,
                   const framework::LoDTensor& seq_tensor,
                   framework::LoDTensor* pad_tensor,
-                  std::vector<T> pad_value = {0}, int pad_seq_len = -1,
+                  const framework::LoDTensor& pad_value, int pad_seq_len = -1,
                   int lod_level = 0, bool norm_by_times = false,
                   const PadLayout layout = kBatchLengthWidth);
 };
@@ -94,7 +96,7 @@ class UnpaddingLoDTensorFunctor {
                   const framework::LoDTensor& pad_tensor,
                   framework::LoDTensor* seq_tensor, int pad_seq_len = -1,
                   int lod_level = 0, bool norm_by_times = false,
-                  const PadLayout& layout = kBatchLengthWidth);
+                  const PadLayout layout = kBatchLengthWidth);
 };
 
 }  // namespace math
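
Note: CopyType, moved into this header, drives a compile-time direction switch: the CUDA kernel is instantiated once per direction, and the template argument selects src/dst without a runtime branch in the inner loop. A standalone CPU sketch of the same trick (illustrative only, not the actual kernel):

enum CopyType { kSeqToPad, kPadToSeq };

// One body serves both copy directions; `Type` is a compile-time constant,
// so each instantiation reduces to a plain forward copy.
template <typename T, CopyType Type>
void CopyRow(T* pad, T* seq, int pad_offset, int seq_offset, int width) {
  T* dst = (Type == kSeqToPad) ? pad + pad_offset : seq + seq_offset;
  const T* src = (Type == kSeqToPad) ? seq + seq_offset : pad + pad_offset;
  for (int i = 0; i < width; ++i) dst[i] = src[i];
}

// e.g. CopyRow<float, kSeqToPad>(pad, seq, po, so, w) copies seq -> pad.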

paddle/fluid/operators/math/sequence_padding_test.cc

Lines changed: 12 additions & 1 deletion

@@ -24,6 +24,8 @@ void TestSequencePadding(const paddle::framework::LoD& lod,
   paddle::framework::LoDTensor seq;
   paddle::framework::LoDTensor seq_back;
   paddle::framework::LoDTensor padding;
+  paddle::framework::LoDTensor cpu_pad_value;
+  paddle::framework::LoDTensor pad_value;
 
   const size_t level = lod.size() - 1;
   auto seq_dims =
@@ -55,8 +57,17 @@ void TestSequencePadding(const paddle::framework::LoD& lod,
 
   padding.mutable_data<T>(padding_dims, *place);
 
+  T* pad_value_data =
+      cpu_pad_value.mutable_data<T>({1}, paddle::platform::CPUPlace());
+  *pad_value_data = static_cast<T>(0);
+  if (paddle::platform::is_cpu_place(*place)) {
+    pad_value = cpu_pad_value;
+  } else {
+    TensorCopySync(cpu_pad_value, *place, &pad_value);
+  }
+
   paddle::operators::math::PaddingLoDTensorFunctor<DeviceContext, T>()(
-      *context, seq, &padding, {0}, -1, 0, false,
+      *context, seq, &padding, pad_value, -1, 0, false,
       paddle::operators::math::kLengthBatchWidth);
 
   seq_back.set_lod(lod);

paddle/fluid/operators/sequence_pad_op.h

Lines changed: 1 addition & 4 deletions

@@ -35,14 +35,11 @@ class SequencePadOpKernel : public framework::OpKernel<T> {
     out->mutable_data<T>(ctx.GetPlace());
 
     const auto* pad_value = ctx.Input<LoDTensor>("PadValue");
-    const T* pad_value_data = pad_value->data<T>();
-    std::vector<T> pad_value_vec(pad_value_data,
-                                 pad_value_data + pad_value->numel());
 
     int padded_length = ctx.Attr<int>("padded_length");
 
     math::PaddingLoDTensorFunctor<DeviceContext, T>()(
-        ctx.template device_context<DeviceContext>(), *x, out, pad_value_vec,
+        ctx.template device_context<DeviceContext>(), *x, out, *pad_value,
         padded_length, 0, false, math::kBatchLengthWidth);
   }
 };

paddle/fluid/operators/warpctc_op.h

Lines changed: 13 additions & 2 deletions

@@ -161,10 +161,21 @@ class WarpCTCKernel : public framework::OpKernel<T> {
                                     static_cast<int64_t>(num_sequences),
                                     static_cast<int64_t>(sequence_width)});
     warpctc_logits.mutable_data<T>(warpctc_logits_dims, ctx.GetPlace());
+
+    LoDTensor cpu_pad_value;
+    T* pad_value_data =
+        cpu_pad_value.mutable_data<T>({1}, platform::CPUPlace());
+    *pad_value_data = static_cast<T>(0);
+    LoDTensor pad_value;
+    if (platform::is_cpu_place(ctx.GetPlace())) {
+      pad_value = cpu_pad_value;
+    } else {
+      TensorCopySync(cpu_pad_value, ctx.GetPlace(), &pad_value);
+    }
+
     math::PaddingLoDTensorFunctor<DeviceContext, T>()(
         ctx.template device_context<DeviceContext>(), *logits, &warpctc_logits,
-        {static_cast<T>(0)}, -1, 0, false /* norm_by_times */,
-        math::kLengthBatchWidth);
+        pad_value, -1, 0, false /* norm_by_times */, math::kLengthBatchWidth);
     const T* warpctc_logits_data = warpctc_logits.data<T>();
 
     std::vector<int> warpctc_label_lengths(num_sequences);
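
Note: this file and sequence_padding_test.cc now repeat the same pattern: build the 1-element pad_value tensor on the CPU, then copy it to the kernel's place when that place is a device. A hypothetical helper capturing the pattern (MakePadValue is not part of this commit) might look like:

// Hypothetical helper, not in this commit: create a 1-element pad_value
// on the CPU and move it to `place` when the op runs on a device.
template <typename T>
framework::LoDTensor MakePadValue(T value, const platform::Place& place) {
  framework::LoDTensor cpu_pad_value;
  T* data = cpu_pad_value.mutable_data<T>({1}, platform::CPUPlace());
  *data = value;
  if (platform::is_cpu_place(place)) {
    return cpu_pad_value;
  }
  framework::LoDTensor pad_value;
  TensorCopySync(cpu_pad_value, place, &pad_value);
  return pad_value;
}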
