
Commit 1715251

update CPU sequence_padding functor
1 parent 8b9938a commit 1715251

File tree

4 files changed (+108, -113 lines)

paddle/fluid/operators/math/sequence_padding.cc

Lines changed: 78 additions & 71 deletions
@@ -18,37 +18,45 @@ namespace paddle {
 namespace operators {
 namespace math {
 
+enum CopyType { kSeqToPad, kPadToSeq };
+
 template <typename T>
-void CopyDataCPU(framework::LoDTensor* seq_tensor,
-                 framework::Tensor* pad_tensor,
-                 const framework::Vector<size_t>& seq_offset,
-                 const int64_t& max_seq_len, const int64_t& seq_width,
-                 bool seq_to_pad, bool norm_by_len,
-                 OutputLayout output_layout) {
-  T* seq_data = seq_tensor->data<T>();
-  T* pad_data = pad_tensor->data<T>();
-
-  int64_t seq_num = seq_offset.size() - 1;
-
-  for (int64_t i = 0; i < seq_num; ++i) {
-    int64_t seq_start = seq_offset[i];
-    int64_t seq_len = seq_offset[i + 1] - seq_start;
-    T scale = norm_by_len ? (1.0f / static_cast<T>(seq_len)) : 1.0f;
-    for (int64_t j = 0; j < seq_len; ++j) {
-      for (int64_t k = 0; k < seq_width; ++k) {
-        size_t pad_data_idx = 0;
-        size_t seq_data_idx = (seq_start + j) * seq_width + k;
-        if (output_layout == kBatchLengthWidth) {
-          pad_data_idx = (i * max_seq_len + j) * seq_width + k;
-        } else {
-          pad_data_idx = (j * seq_num + i) * seq_width + k;
-        }
-        if (seq_to_pad) {
-          pad_data[pad_data_idx] = seq_data[seq_data_idx] * scale;
-        } else {
-          seq_data[seq_data_idx] = pad_data[pad_data_idx] * scale;
+void CopyValidData(framework::Tensor* dst_tensor,
+                   const framework::Tensor* src_tensor,
+                   const framework::Vector<size_t>& seq_offsets,
+                   int pad_seq_len, int step_width, bool norm_by_len,
+                   CopyType type, PadLayout layout) {
+  int seq_num = seq_offsets.size() - 1;
+  const T* src_data = src_tensor->data<T>();
+  T* dst_data = dst_tensor->data<T>();
+
+  int seq_cpy_gap = step_width;
+  int pad_cpy_gap =
+      layout == kBatchLengthWidth ? step_width : seq_num * step_width;
+  for (int seq_idx = 0; seq_idx < seq_num; ++seq_idx) {
+    int valid_seq_len = seq_offsets[seq_idx + 1] - seq_offsets[seq_idx];
+    PADDLE_ENFORCE_GE(
+        pad_seq_len, valid_seq_len,
+        "The padded sequence length can not be less than its original length.");
+    int seq_data_offset = seq_offsets[seq_idx] * step_width;
+    int pad_data_offset = layout == kBatchLengthWidth
+                              ? seq_idx * pad_seq_len * step_width
+                              : seq_idx * step_width;
+    float scale = 1.0f / static_cast<float>(valid_seq_len);
+
+    for (int step_idx = 0; step_idx < valid_seq_len; ++step_idx) {
+      const T* src =
+          src_data + (type == kSeqToPad ? seq_data_offset : pad_data_offset);
+      T* dst =
+          dst_data + (type == kSeqToPad ? pad_data_offset : seq_data_offset);
+      memcpy(dst, src, step_width * sizeof(T));
+      if (norm_by_len) {
+        for (int i = 0; i < step_width; ++i) {
+          *(dst + i) *= scale;
         }
       }
+      seq_data_offset += seq_cpy_gap;
+      pad_data_offset += pad_cpy_gap;
     }
   }
 }
@@ -58,62 +66,61 @@ class PaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
  public:
   void operator()(const platform::CPUDeviceContext& context,
                   const framework::LoDTensor& seq_tensor,
-                  framework::Tensor* pad_tensor,
-                  T pad_value = static_cast<T>(0), bool norm_by_times = false,
-                  size_t lod_level = 0,
-                  OutputLayout output_layout = kBatchLengthWidth) {
-    CheckLoD(seq_tensor, lod_level);
-
-    auto& lod = seq_tensor.lod();
-    auto& seq_offset = framework::ToAbsOffset(lod)[lod_level];
-
+                  framework::LoDTensor* pad_tensor,
+                  std::vector<T> pad_value = {0}, int pad_seq_len = -1,
+                  int lod_level = 0, bool norm_by_times = false,
+                  const PadLayout layout = kBatchLengthWidth) {
+    auto seq_offsets = framework::ToAbsOffset(seq_tensor.lod())[lod_level];
     auto seq_tensor_dims = seq_tensor.dims();
     auto pad_tensor_dims = pad_tensor->dims();
-    int64_t max_seq_len = MaximumSequenceLength(seq_offset);
-    int64_t seq_num = seq_offset.size() - 1;
-    int64_t seq_width = seq_tensor.numel() / seq_tensor_dims[0];
+    if (pad_seq_len == -1) {
+      pad_seq_len = MaximumSequenceLength(seq_offsets);
+    }
+    int step_width = seq_tensor.numel() / seq_tensor_dims[0];
 
-    CheckDims(seq_tensor_dims, seq_offset.back(), pad_tensor_dims, max_seq_len,
-              seq_num, seq_width, output_layout);
+    CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len,
+              step_width, layout);
+    PADDLE_ENFORCE(pad_value.size() == 1 ||
+                       static_cast<int>(pad_value.size()) == step_width,
+                   "The size of 'pad_value' can only be 1 or be equal to the "
+                   "'step_width'.");
 
-    T* pad_data = pad_tensor->data<T>();
+    if (pad_value.size() == 1) {
+      pad_value = std::vector<T>(step_width, pad_value[0]);
+    }
 
-    memset(pad_data, pad_value, max_seq_len * seq_num * seq_width * sizeof(T));
+    // fill padding value
+    T* pad_data = pad_tensor->data<T>();
+    for (int i = 0; i < pad_tensor->numel() / step_width; ++i) {
+      memcpy(pad_data, pad_value.data(), step_width * sizeof(T));
+    }
 
-    CopyDataCPU<T>(const_cast<framework::LoDTensor*>(&seq_tensor), pad_tensor,
-                   seq_offset, max_seq_len, seq_width, true /* seq_to_pad */,
-                   norm_by_times, output_layout);
+    CopyValidData<T>(pad_tensor, &seq_tensor, seq_offsets, pad_seq_len,
+                     step_width, norm_by_times, kSeqToPad, layout);
   }
 };
 
 template <typename T>
 class UnpaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
  public:
   void operator()(const platform::CPUDeviceContext& context,
-                  framework::LoDTensor* seq_tensor,
-                  const framework::Tensor& pad_tensor,
-                  bool norm_by_times = false, size_t lod_level = 0,
-                  OutputLayout output_layout = kBatchLengthWidth) {
-    CheckLoD(*seq_tensor, lod_level);
-
-    auto& lod = seq_tensor->lod();
-    auto& seq_offset = framework::ToAbsOffset(lod)[lod_level];
-
-    auto& seq_tensor_dims = seq_tensor->dims();
-    auto& pad_tensor_dims = pad_tensor.dims();
-    int64_t max_seq_len = MaximumSequenceLength(seq_offset);
-    int64_t seq_num = seq_offset.size() - 1;
-    int64_t seq_width = seq_tensor->numel() / seq_tensor_dims[0];
-
-    CheckDims(seq_tensor_dims, seq_offset.back(), pad_tensor_dims, max_seq_len,
-              seq_num, seq_width, output_layout);
-
-    T* seq_data = seq_tensor->data<T>();
-    memset(seq_data, static_cast<T>(0), seq_tensor->numel() * sizeof(T));
-
-    CopyDataCPU<T>(seq_tensor, const_cast<framework::Tensor*>(&pad_tensor),
-                   seq_offset, max_seq_len, seq_width, false /* seq_to_pad */,
-                   norm_by_times, output_layout);
+                  const framework::LoDTensor& pad_tensor,
+                  framework::LoDTensor* seq_tensor, int pad_seq_len = -1,
+                  int lod_level = 0, bool norm_by_times = false,
+                  const PadLayout& layout = kBatchLengthWidth) {
+    auto seq_offsets = framework::ToAbsOffset(seq_tensor->lod())[lod_level];
+    auto seq_tensor_dims = seq_tensor->dims();
+    auto pad_tensor_dims = pad_tensor.dims();
+    if (pad_seq_len == -1) {
+      pad_seq_len = MaximumSequenceLength(seq_offsets);
+    }
+    int step_width = seq_tensor->numel() / seq_tensor_dims[0];
+
+    CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len,
+              step_width, layout);
+
+    CopyValidData<T>(seq_tensor, &pad_tensor, seq_offsets, pad_seq_len,
+                     step_width, norm_by_times, kPadToSeq, layout);
   }
 };
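For reference, here is a minimal standalone sketch (plain C++, independent of Paddle's tensor types; the buffer names and sizes are invented for illustration) of the copy scheme CopyValidData implements on the kSeqToPad / kBatchLengthWidth path: each sequence's valid steps are memcpy'd row by row into a pre-filled padded buffer, with the source offset advancing by step_width and the destination offset advancing by the layout-dependent gap.

// Toy illustration (invented data): pad two variable-length sequences of
// width 2 into a zero-filled [seq_num, pad_seq_len, step_width] buffer,
// mirroring the kSeqToPad / kBatchLengthWidth case of CopyValidData.
#include <cstdio>
#include <cstring>
#include <vector>

int main() {
  const int step_width = 2, pad_seq_len = 3;
  std::vector<size_t> seq_offsets = {0, 2, 5};  // sequence lengths: 2 and 3
  std::vector<float> seq = {1, 1, 2, 2, 3, 3, 4, 4, 5, 5};
  const int seq_num = static_cast<int>(seq_offsets.size()) - 1;
  std::vector<float> pad(seq_num * pad_seq_len * step_width, 0.f);

  for (int seq_idx = 0; seq_idx < seq_num; ++seq_idx) {
    int valid_seq_len =
        static_cast<int>(seq_offsets[seq_idx + 1] - seq_offsets[seq_idx]);
    int seq_data_offset = static_cast<int>(seq_offsets[seq_idx]) * step_width;
    int pad_data_offset = seq_idx * pad_seq_len * step_width;  // batch-major
    for (int step_idx = 0; step_idx < valid_seq_len; ++step_idx) {
      std::memcpy(&pad[pad_data_offset], &seq[seq_data_offset],
                  step_width * sizeof(float));
      seq_data_offset += step_width;  // seq_cpy_gap
      pad_data_offset += step_width;  // pad_cpy_gap for kBatchLengthWidth
    }
  }
  for (float v : pad) std::printf("%.0f ", v);  // 1 1 2 2 0 0 3 3 4 4 5 5
  std::printf("\n");
  return 0;
}

For kLengthBatchWidth the only change would be that pad_data_offset starts at seq_idx * step_width and advances by seq_num * step_width per step, as in the new functor code above.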

paddle/fluid/operators/math/sequence_padding.h

Lines changed: 22 additions & 34 deletions
@@ -15,14 +15,15 @@ limitations under the License. */
 #pragma once
 
 #include <algorithm>
+#include <vector>
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/platform/device_context.h"
 
 namespace paddle {
 namespace operators {
 namespace math {
 
-enum OutputLayout { kBatchLengthWidth = 0, kLengthBatchWidth };
+enum PadLayout { kBatchLengthWidth = 0, kLengthBatchWidth };
 
 inline static size_t MaximumSequenceLength(
     const framework::Vector<size_t>& seq_offset) {
@@ -34,35 +35,22 @@ inline static size_t MaximumSequenceLength(
   return max_seq_len;
 }
 
-inline static void CheckLoD(const framework::LoDTensor& seq_tensor,
-                            const size_t& lod_level) {
-  PADDLE_ENFORCE(lod_level < seq_tensor.lod().size(),
-                 "Invalid lod level which should be at least 0 and less "
-                 "than maximum lod level of sequence tensor.");
-}
-
 inline static void CheckDims(const framework::DDim& seq_tensor_dims,
-                             const size_t& last_offset,
                              const framework::DDim& pad_tensor_dims,
-                             const int64_t& max_seq_len, const int64_t& seq_num,
-                             const int64_t& seq_width,
-                             const OutputLayout& output_layout) {
-  PADDLE_ENFORCE_EQ(static_cast<size_t>(seq_tensor_dims[0]), last_offset,
+                             const framework::Vector<size_t>& seq_offset,
+                             int64_t padded_seq_len, int64_t step_width,
+                             const PadLayout& layout) {
+  PADDLE_ENFORCE_EQ(static_cast<size_t>(seq_tensor_dims[0]), seq_offset.back(),
                     "Value of 1st dimension of the sequence tensor should be "
                     "equal to sum of lengths of all sequences.");
 
-  PADDLE_ENFORCE_EQ(pad_tensor_dims.size(), 3UL,
-                    "Padded tensor should be a 3-D tensor.");
+  PADDLE_ENFORCE(seq_tensor_dims.size() == 1 || seq_tensor_dims.size() == 2,
+                 "seq_tensor's rank should be 1 or 2.");
 
-  if (output_layout == kBatchLengthWidth) {
-    PADDLE_ENFORCE_EQ(pad_tensor_dims,
-                      framework::make_ddim({seq_num, max_seq_len, seq_width}));
-  } else if (output_layout == kLengthBatchWidth) {
-    PADDLE_ENFORCE_EQ(pad_tensor_dims,
-                      framework::make_ddim({max_seq_len, seq_num, seq_width}));
-  } else {
-    PADDLE_THROW("Unsupported output layout.");
-  }
+  PADDLE_ENFORCE(seq_tensor_dims.size() + 1 == pad_tensor_dims.size() ||
+                     seq_tensor_dims.size() == pad_tensor_dims.size(),
+                 "pad_tensor's rank should be 1 greater than seq_tensor's "
+                 "rank, or be equal with it.");
 }
 
 /*
@@ -94,22 +82,22 @@ inline static void CheckDims(const framework::DDim& seq_tensor_dims,
 template <typename DeviceContext, typename T>
 class PaddingLoDTensorFunctor {
  public:
-  void operator()(const DeviceContext& context,
+  void operator()(const platform::CPUDeviceContext& context,
                   const framework::LoDTensor& seq_tensor,
-                  framework::Tensor* pad_tensor,
-                  T pad_value = static_cast<T>(0), bool norm_by_times = false,
-                  size_t lod_level = 0,
-                  OutputLayout output_layout = kBatchLengthWidth);
+                  framework::LoDTensor* pad_tensor,
+                  std::vector<T> pad_value = {0}, int pad_seq_len = -1,
+                  int lod_level = 0, bool norm_by_times = false,
+                  const PadLayout layout = kBatchLengthWidth);
 };
 
 template <typename DeviceContext, typename T>
 class UnpaddingLoDTensorFunctor {
  public:
-  void operator()(const DeviceContext& context,
-                  framework::LoDTensor* seq_tensor,
-                  const framework::Tensor& pad_tensor,
-                  bool norm_by_times = false, size_t lod_level = 0,
-                  OutputLayout output_layout = kBatchLengthWidth);
+  void operator()(const platform::CPUDeviceContext& context,
+                  const framework::LoDTensor& pad_tensor,
+                  framework::LoDTensor* seq_tensor, int pad_seq_len = -1,
+                  int lod_level = 0, bool norm_by_times = false,
+                  const PadLayout& layout = kBatchLengthWidth);
 };
 
 } // namespace math
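With these new declarations, existing call sites migrate from the old positional arguments to the new ones. A rough sketch (assuming ctx is a platform::CPUDeviceContext and seq, padding, seq_back are LoDTensors whose dims and LoD are prepared as in the updated test below):

// Old pad call: functor(ctx, seq, &padding, /*pad_value=*/0,
//                       /*norm_by_times=*/false, /*lod_level=*/0, layout);
// New call: pad_value becomes a std::vector<T>, pad_seq_len (-1 means "use
// the maximum sequence length") is inserted, and norm_by_times moves after
// lod_level.
paddle::operators::math::PaddingLoDTensorFunctor<
    paddle::platform::CPUDeviceContext, float>()(
    ctx, seq, &padding, /*pad_value=*/{0}, /*pad_seq_len=*/-1,
    /*lod_level=*/0, /*norm_by_times=*/false,
    paddle::operators::math::kLengthBatchWidth);

// Unpadding now takes (pad_tensor, seq_tensor*) instead of
// (seq_tensor*, pad_tensor), with the same pad_seq_len / lod_level /
// norm_by_times ordering.
paddle::operators::math::UnpaddingLoDTensorFunctor<
    paddle::platform::CPUDeviceContext, float>()(
    ctx, padding, &seq_back, /*pad_seq_len=*/-1, /*lod_level=*/0,
    /*norm_by_times=*/false, paddle::operators::math::kLengthBatchWidth);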

paddle/fluid/operators/math/sequence_padding_test.cc

Lines changed: 3 additions & 3 deletions
@@ -23,7 +23,7 @@ void TestSequencePadding(const paddle::framework::LoD& lod,
   paddle::framework::LoDTensor cpu_seq_back;
   paddle::framework::LoDTensor seq;
   paddle::framework::LoDTensor seq_back;
-  paddle::framework::Tensor padding;
+  paddle::framework::LoDTensor padding;
 
   const size_t level = lod.size() - 1;
   auto seq_dims =
@@ -56,13 +56,13 @@ void TestSequencePadding(const paddle::framework::LoD& lod,
   padding.mutable_data<T>(padding_dims, *place);
 
   paddle::operators::math::PaddingLoDTensorFunctor<DeviceContext, T>()(
-      *context, seq, &padding, 0, false, 0,
+      *context, seq, &padding, {0}, -1, 0, false,
       paddle::operators::math::kLengthBatchWidth);
 
   seq_back.set_lod(lod);
   seq_back.mutable_data<T>(seq_dims, *place);
   paddle::operators::math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
-      *context, &seq_back, padding, false, 0,
+      *context, padding, &seq_back, -1, 0, false,
       paddle::operators::math::kLengthBatchWidth);
 
   if (paddle::platform::is_cpu_place(*place)) {
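As a reminder of the padded shape that pairs with each layout (this mirrors the shape checks the old CheckDims performed; the helper below is not part of the commit, just an illustrative sketch):

// Hypothetical helper: padded dims for a 2-D sequence tensor of width
// step_width, for either PadLayout.
//   kBatchLengthWidth -> {seq_num, pad_seq_len, step_width}
//   kLengthBatchWidth -> {pad_seq_len, seq_num, step_width}
inline paddle::framework::DDim PaddedDims(
    int64_t seq_num, int64_t pad_seq_len, int64_t step_width,
    paddle::operators::math::PadLayout layout) {
  namespace f = paddle::framework;
  return layout == paddle::operators::math::kBatchLengthWidth
             ? f::make_ddim({seq_num, pad_seq_len, step_width})
             : f::make_ddim({pad_seq_len, seq_num, step_width});
}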

paddle/fluid/operators/warpctc_op.h

Lines changed: 5 additions & 5 deletions
@@ -153,7 +153,7 @@ class WarpCTCKernel : public framework::OpKernel<T> {
         framework::make_ddim({static_cast<int64_t>(num_sequences), 1});
 
     // warpctc needs sequences data stored in transposed padding format
-    Tensor warpctc_logits;
+    LoDTensor warpctc_logits;
     const size_t max_sequence_length =
         math::MaximumSequenceLength(logits_lod[level]);
     auto warpctc_logits_dims =
@@ -163,7 +163,7 @@ class WarpCTCKernel : public framework::OpKernel<T> {
     warpctc_logits.mutable_data<T>(warpctc_logits_dims, ctx.GetPlace());
     math::PaddingLoDTensorFunctor<DeviceContext, T>()(
         ctx.template device_context<DeviceContext>(), *logits, &warpctc_logits,
-        static_cast<T>(0), false /* norm_by_times */, 0,
+        {static_cast<T>(0)}, -1, 0, false /* norm_by_times */,
         math::kLengthBatchWidth);
     const T* warpctc_logits_data = warpctc_logits.data<T>();
 
@@ -210,15 +210,15 @@ template <typename DeviceContext, typename T>
 class WarpCTCGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* warpctc_grad = ctx.Input<Tensor>("WarpCTCGrad");
+    auto* warpctc_grad = ctx.Input<LoDTensor>("WarpCTCGrad");
     auto* logits_grad = ctx.Output<LoDTensor>(framework::GradVarName("Logits"));
     const Tensor* loss_grad = ctx.Input<Tensor>(framework::GradVarName("Loss"));
 
     logits_grad->mutable_data<T>(ctx.GetPlace());
     bool norm_by_times = ctx.Attr<bool>("norm_by_times");
     math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
-        ctx.template device_context<DeviceContext>(), logits_grad,
-        *warpctc_grad, norm_by_times, 0, math::kLengthBatchWidth);
+        ctx.template device_context<DeviceContext>(), *warpctc_grad,
+        logits_grad, -1, 0, norm_by_times, math::kLengthBatchWidth);
 
     const T* loss_grad_data = loss_grad->data<T>();
     math::ScaleLoDTensorFunctor<DeviceContext, T>()(
