@@ -27,30 +27,47 @@ class SequenceEnumerateKernel : public framework::OpKernel<T> {
27
27
auto * in = context.Input <LoDTensor>(" X" );
28
28
auto * out = context.Output <LoDTensor>(" Out" );
29
29
int win_size = context.Attr <int >(" win_size" );
30
- int pad_value = context.Attr <int >(" pad_value" );
30
+ auto pad_value = static_cast <T>( context.Attr <int >(" pad_value" ) );
31
31
32
32
auto in_dims = in->dims ();
33
- auto in_lod = in->lod ();
34
-
33
+ auto lod0 = in->lod ()[0 ];
35
34
PADDLE_ENFORCE_EQ (
36
- static_cast <uint64_t >(in_dims[0 ]), in_lod[ 0 ] .back (),
35
+ static_cast <uint64_t >(in_dims[0 ]), lod0 .back (),
37
36
" The actual input data's size mismatched with LoD information." );
37
+ PADDLE_ENFORCE_EQ (
38
+ in_dims.size (), 2UL ,
39
+ " Input(X) of SequenceEnumerate operator's rank should be 2." );
40
+ PADDLE_ENFORCE_EQ (in_dims[1 ], 1 ,
41
+ " Input(X) of SequenceEnumerate operator's 2nd "
42
+ " dimension should be 1." );
38
43
39
44
// Generate enumerate sequence set
40
- auto lod0 = in_lod[0 ];
41
45
auto in_data = in->data <T>();
42
46
out->Resize ({in_dims[0 ], win_size});
47
+ out->set_lod (in->lod ());
43
48
auto out_data = out->mutable_data <T>(context.GetPlace ());
44
49
for (size_t i = 0 ; i < lod0.size () - 1 ; ++i) {
45
- for (size_t idx = lod0[i]; idx < lod0[i + 1 ]; ++idx) {
46
- for (int word_idx = 0 ; word_idx < win_size; ++word_idx) {
47
- size_t word_pos = idx + word_idx;
48
- out_data[win_size * idx + word_idx] =
49
- word_pos < lod0[i + 1 ] ? in_data[word_pos] : pad_value;
50
+ int start = lod0[i];
51
+ int end = lod0[i + 1 ];
52
+ int copy_size = win_size < end - start + 1 ? win_size : end - start + 1 ;
53
+ int mid = end + 1 - copy_size;
54
+ int pad_num = win_size - copy_size;
55
+ copy_size *= sizeof (T);
56
+ for (int idx = start; idx < mid; ++idx) {
57
+ std::memcpy (out_data, in_data + idx, copy_size);
58
+ out_data += win_size;
59
+ }
60
+ for (int idx = mid; idx < end; ++idx) {
61
+ copy_size -= sizeof (T);
62
+ pad_num++;
63
+ std::memcpy (out_data, in_data + idx, copy_size);
64
+ T* pdata = out_data + copy_size / sizeof (T);
65
+ for (int i = 0 ; i < pad_num; ++i) {
66
+ pdata[i] = pad_value;
50
67
}
68
+ out_data += win_size;
51
69
}
52
70
}
53
- out->set_lod (in->lod ());
54
71
}
55
72
};
56
73
0 commit comments