Commit 0c4697f

fix: change to enumerate by sentence

1 parent 4ec1249

5 files changed: 48 additions, 28 deletions

paddle/fluid/operators/sequence_enumerate_op.cc

Lines changed: 2 additions & 2 deletions

@@ -72,14 +72,14 @@ The values of the last insufficient part are all filled with the input pad_value.
       Case 1:
         Input:
           X.lod = [[0, 3, 5]]
-          X.data = [1, 2, 3, 4, 5]
+          X.data = [[1], [2], [3], [4], [5]]
           X.dims = [5, 1]
         Attrs:
           win_size = 2
           pad_value = 0
         Output:
           Out.lod = [[0, 3, 5]]
-          Out.data = [[1, 2], [2, 3], [3, 4], [4, 5], [0, 0]]
+          Out.data = [[1, 2], [2, 3], [3, 0], [4, 5], [5, 0]]
           Out.dims = [5, 2]
 
 )DOC");
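To make the new per-sentence behavior concrete, here is a small standalone Python sketch (not part of the commit; the function name is illustrative) that reproduces the documented Case 1: windows are enumerated inside each LoD interval and padded at the sentence boundary instead of running across it.

def enumerate_by_sentence(data, lod_offsets, win_size, pad_value):
    # lod_offsets is offset-based, e.g. [0, 3, 5] splits data into
    # the sentences [1, 2, 3] and [4, 5].
    out = []
    for start, end in zip(lod_offsets[:-1], lod_offsets[1:]):
        for idx in range(start, end):
            out.append([data[idx + w] if idx + w < end else pad_value
                        for w in range(win_size)])
    return out

# Case 1 from the op documentation:
# enumerate_by_sentence([1, 2, 3, 4, 5], [0, 3, 5], 2, 0)
#   -> [[1, 2], [2, 3], [3, 0], [4, 5], [5, 0]]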

paddle/fluid/operators/sequence_enumerate_op.cu

Lines changed: 17 additions & 6 deletions

@@ -23,15 +23,23 @@ using platform::PADDLE_CUDA_NUM_THREADS;
 using LoDTensor = framework::LoDTensor;
 
 template <typename T>
-__global__ void CalcOutPut(const T* in_data, const int64_t in_len,
-                           const int64_t win_size, const int64_t pad_value,
-                           T* out_data) {
+__global__ void CalcOutPut(const T* in_data, const size_t* in_lod,
+                           const size_t lod_len, const int64_t win_size,
+                           const int64_t pad_value, T* out_data) {
   int index = blockIdx.x * blockDim.x + threadIdx.x;
-  if (index < in_len) {
+  if (index < in_lod[lod_len - 1]) {
+    int end_idx = 0;
+    // Get LoD interval of index
+    for (int i = 1; i < lod_len; ++i) {
+      if (index < in_lod[i]) {
+        end_idx = in_lod[i];
+        break;
+      }
+    }
     for (size_t i = 0; i < win_size; ++i) {
       int word_pos = index + i;
       out_data[index * win_size + i] =
-          word_pos < in_len ? in_data[word_pos] : pad_value;
+          word_pos < end_idx ? in_data[word_pos] : pad_value;
     }
   }
 }
@@ -54,13 +62,16 @@ class SequenceEnumerateOpCUDAKernel : public framework::OpKernel<T> {
 
     /* Generate enumerate sequence set */
     auto stream = context.cuda_device_context().stream();
+    auto lod0 = in_lod[0];
     auto in_len = in->numel();
     auto in_data = in->data<T>();
     auto out_data = out->mutable_data<T>(context.GetPlace());
+    // Copy LoD to GPU
+    const size_t* dev_in_lod_ptr = lod0.CUDAData(context.GetPlace());
     // Calc output tensor
     CalcOutPut<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
                  PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
-        in_data, in_len, win_size, pad_value, out_data);
+        in_data, dev_in_lod_ptr, lod0.size(), win_size, pad_value, out_data);
   }
 };

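A rough Python analogue of the rewritten kernel's per-thread work may help (not from the commit; names are illustrative): each flattened index first scans the offset-based LoD to find the end of the sentence it belongs to, then fills its window, padding once the window reaches that boundary.

def enumerate_one_index(in_data, lod_offsets, index, win_size, pad_value):
    # Mirror of the kernel's linear scan: find the end offset of the
    # LoD interval that contains `index` (lod_offsets is e.g. [0, 3, 5]).
    end_idx = 0
    for offset in lod_offsets[1:]:
        if index < offset:
            end_idx = offset
            break
    # Fill one output row; positions past the sentence end get pad_value.
    return [in_data[index + w] if index + w < end_idx else pad_value
            for w in range(win_size)]

# e.g. enumerate_one_index([1, 2, 3, 4, 5], [0, 3, 5], 2, 2, 0) -> [3, 0]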
paddle/fluid/operators/sequence_enumerate_op.h

Lines changed: 8 additions & 6 deletions

@@ -37,14 +37,16 @@ class SequenceEnumerateKernel : public framework::OpKernel<T> {
         "The actual input data's size mismatched with LoD information.");
 
     // Generate enumerate sequence set
-    auto seq_length = in_dims[0];
+    auto lod0 = in_lod[0];
     auto in_data = in->data<T>();
     auto out_data = out->mutable_data<T>(context.GetPlace());
-    for (int idx = 0; idx < seq_length; ++idx) {
-      for (int word_idx = 0; word_idx < win_size; ++word_idx) {
-        int word_pos = idx + word_idx;
-        out_data[win_size * idx + word_idx] =
-            word_pos < seq_length ? in_data[word_pos] : pad_value;
+    for (size_t i = 0; i < lod0.size() - 1; ++i) {
+      for (size_t idx = lod0[i]; idx < lod0[i + 1]; ++idx) {
+        for (int word_idx = 0; word_idx < win_size; ++word_idx) {
+          size_t word_pos = idx + word_idx;
+          out_data[win_size * idx + word_idx] =
+              word_pos < lod0[i + 1] ? in_data[word_pos] : pad_value;
+        }
       }
     }
   }

python/paddle/fluid/layers/nn.py

Lines changed: 3 additions & 3 deletions

@@ -5534,14 +5534,14 @@ def sequence_enumerate(input, win_size, pad_value, name=None):
         Case 1:
           Input:
             X.lod = [[0, 3, 5]]
-            X.data = [1, 2, 3, 4, 5]
+            X.data = [[1], [2], [3], [4], [5]]
             X.dims = [5, 1]
           Attrs:
             win_size = 2
             pad_value = 0
           Output:
             Out.lod = [[0, 3, 5]]
-            Out.data = [[1, 2], [2, 3], [3, 4], [4, 5], [0, 0]]
+            Out.data = [[1, 2], [2, 3], [3, 0], [4, 5], [5, 0]]
             Out.dims = [5, 2]
 
     Args:
@@ -5567,7 +5567,7 @@ def sequence_enumerate(input, win_size, pad_value, name=None):
         attrs={'win_size': win_size,
                'pad_value': pad_value})
 
-
+
 def sequence_mask(x, maxlen=None, dtype='int64', name=None):
     """
     **SequenceMask Layer**
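For context, a minimal usage sketch of the layer under the updated per-sentence semantics (illustrative only; assumes the fluid.layers API on this branch):

import paddle.fluid as fluid

# A 1-level LoDTensor of word ids; each LoD interval is one sentence.
x = fluid.layers.data(name='x', shape=[1], dtype='int32', lod_level=1)
# Enumerate win_size-grams per sentence; windows never cross a sentence
# boundary and are padded with pad_value at each sentence end.
out = fluid.layers.sequence_enumerate(input=x, win_size=2, pad_value=0)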

python/paddle/fluid/tests/unittests/test_sequence_enumerate_op.py

Lines changed: 18 additions & 11 deletions

@@ -19,16 +19,20 @@
 from op_test import OpTest
 
 
-def sequence_enumerate(input_seq, win_size, pad_value):
+def sequence_enumerate(input_seq, in_lod, win_size, pad_value):
+    lod0 = [0]
+    for i in range(0, len(in_lod[0])):
+        lod0.append(lod0[i] + in_lod[0][i])
     out_seq = []
-    for idx in range(0, len(input_seq)):
-        single_seq = []
-        for word_idx in range(win_size):
-            word_pos = idx + word_idx
-            dat = input_seq[word_pos] if word_pos < len(input_seq) \
+    for i in range(0, len(lod0) - 1):
+        for idx in range(lod0[i], lod0[i + 1]):
+            single_seq = []
+            for word_idx in range(win_size):
+                word_pos = idx + word_idx
+                dat = input_seq[word_pos] if word_pos < lod0[i+1] \
                 else pad_value
-            single_seq.append(dat)
-        out_seq.append(single_seq)
+                single_seq.append(dat)
+            out_seq.append(single_seq)
     return out_seq
 
 
@@ -48,7 +52,8 @@ def init_test_case(self):
         self.lod = [[9, 4, 11, 6]]
         self.win_size = 2
         self.pad_value = 0
-        out_seq = sequence_enumerate(self.in_seq, self.win_size, self.pad_value)
+        out_seq = sequence_enumerate(self.in_seq, self.lod, self.win_size,
+                                     self.pad_value)
         self.out_seq = np.array(out_seq).astype("int32")
 
 
@@ -58,7 +63,8 @@ def init_test_case(self):
         self.lod = [[9, 4, 11, 6]]
         self.win_size = 2
         self.pad_value = 0
-        out_seq = sequence_enumerate(self.in_seq, self.win_size, self.pad_value)
+        out_seq = sequence_enumerate(self.in_seq, self.lod, self.win_size,
+                                     self.pad_value)
         self.out_seq = np.array(out_seq).astype("int64")
 
 
@@ -68,7 +74,8 @@ def init_test_case(self):
         self.lod = [[9, 4, 11, 6]]
         self.win_size = 30
         self.pad_value = 0
-        out_seq = sequence_enumerate(self.in_seq, self.win_size, self.pad_value)
+        out_seq = sequence_enumerate(self.in_seq, self.lod, self.win_size,
+                                     self.pad_value)
         self.out_seq = np.array(out_seq).astype("int32")

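As a quick sanity check (not part of the commit), the updated reference helper reproduces the documented Case 1 when given the equivalent length-based LoD [[3, 2]]:

# Hypothetical check against Case 1 from the op documentation.
print(sequence_enumerate([1, 2, 3, 4, 5], [[3, 2]], 2, 0))
# -> [[1, 2], [2, 3], [3, 0], [4, 5], [5, 0]]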