Skip to content

Commit ead558b

Browse files
authored
Merge pull request #16256 from tensor-tang/refine/seqenum
refine sequence enumerate op
2 parents c7f1f3e + 50931de commit ead558b

File tree

2 files changed

+28
-18
lines changed

2 files changed

+28
-18
lines changed

paddle/fluid/operators/sequence_ops/sequence_enumerate_op.cc

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,6 @@ class SequenceEnumerateOp : public framework::OperatorWithKernel {
3030
"Output(X) of SequenceEnumerate operator should not be null.");
3131

3232
const auto x_dims = ctx->GetInputDim("X");
33-
PADDLE_ENFORCE_EQ(
34-
x_dims.size(), 2,
35-
"Input(X) of SequenceEnumerate operator's rank should be 2.");
36-
PADDLE_ENFORCE_EQ(x_dims[1], 1,
37-
"Input(X) of SequenceEnumerate operator's 2nd "
38-
"dimension should be 1.");
39-
4033
const auto win_size = ctx->Attrs().Get<int>("win_size");
4134
ctx->SetOutputDim("Out", {x_dims[0], win_size});
4235
ctx->ShareLoD("X", "Out");

paddle/fluid/operators/sequence_ops/sequence_enumerate_op.h

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,30 +27,47 @@ class SequenceEnumerateKernel : public framework::OpKernel<T> {
2727
auto* in = context.Input<LoDTensor>("X");
2828
auto* out = context.Output<LoDTensor>("Out");
2929
int win_size = context.Attr<int>("win_size");
30-
int pad_value = context.Attr<int>("pad_value");
30+
auto pad_value = static_cast<T>(context.Attr<int>("pad_value"));
3131

3232
auto in_dims = in->dims();
33-
auto in_lod = in->lod();
34-
33+
auto lod0 = in->lod()[0];
3534
PADDLE_ENFORCE_EQ(
36-
static_cast<uint64_t>(in_dims[0]), in_lod[0].back(),
35+
static_cast<uint64_t>(in_dims[0]), lod0.back(),
3736
"The actual input data's size mismatched with LoD information.");
37+
PADDLE_ENFORCE_EQ(
38+
in_dims.size(), 2UL,
39+
"Input(X) of SequenceEnumerate operator's rank should be 2.");
40+
PADDLE_ENFORCE_EQ(in_dims[1], 1,
41+
"Input(X) of SequenceEnumerate operator's 2nd "
42+
"dimension should be 1.");
3843

3944
// Generate enumerate sequence set
40-
auto lod0 = in_lod[0];
4145
auto in_data = in->data<T>();
4246
out->Resize({in_dims[0], win_size});
47+
out->set_lod(in->lod());
4348
auto out_data = out->mutable_data<T>(context.GetPlace());
4449
for (size_t i = 0; i < lod0.size() - 1; ++i) {
45-
for (size_t idx = lod0[i]; idx < lod0[i + 1]; ++idx) {
46-
for (int word_idx = 0; word_idx < win_size; ++word_idx) {
47-
size_t word_pos = idx + word_idx;
48-
out_data[win_size * idx + word_idx] =
49-
word_pos < lod0[i + 1] ? in_data[word_pos] : pad_value;
50+
int start = lod0[i];
51+
int end = lod0[i + 1];
52+
int copy_size = win_size < end - start + 1 ? win_size : end - start + 1;
53+
int mid = end + 1 - copy_size;
54+
int pad_num = win_size - copy_size;
55+
copy_size *= sizeof(T);
56+
for (int idx = start; idx < mid; ++idx) {
57+
std::memcpy(out_data, in_data + idx, copy_size);
58+
out_data += win_size;
59+
}
60+
for (int idx = mid; idx < end; ++idx) {
61+
copy_size -= sizeof(T);
62+
pad_num++;
63+
std::memcpy(out_data, in_data + idx, copy_size);
64+
T* pdata = out_data + copy_size / sizeof(T);
65+
for (int i = 0; i < pad_num; ++i) {
66+
pdata[i] = pad_value;
5067
}
68+
out_data += win_size;
5169
}
5270
}
53-
out->set_lod(in->lod());
5471
}
5572
};
5673

0 commit comments

Comments
 (0)