@@ -109,24 +109,32 @@ class LastFirstSeqPoolFunctor {
109
109
void operator ()(const platform::CPUDeviceContext& context,
110
110
const framework::LoDTensor& input, framework::Tensor* output,
111
111
const std::string pooltype) {
112
+ // Create pointers to input and output data
112
113
auto * in_data = input.data <T>();
113
114
auto * out_data = output->data <T>();
115
+
116
+ // Calculate length of each word
114
117
int64_t word_len = input.numel () / input.dims ()[0 ];
115
118
auto lod = input.lod ()[0 ];
116
- auto dims = input.dims ();
117
119
if (pooltype == " LAST" ){
118
120
for (int i=0 ; i < static_cast <int >(lod.size ()) - 1 ; ++i ){
121
+ // Calculate the length of each sequence
119
122
int64_t seq_len = static_cast <int64_t >(lod[i + 1 ] - lod[i]);
123
+ // Point to the begin of next sequence
120
124
in_data += seq_len* word_len;
121
- std::memcpy (out_data,(in_data-word_len),word_len*sizeof (int ));
125
+ // Copy the last words to output
126
+ std::memcpy (out_data,(in_data-word_len),word_len*sizeof (T));
122
127
out_data += word_len;
123
128
124
129
}
125
130
}
126
131
else if (pooltype == " FIRST" ){
127
132
for (int i=0 ; i < static_cast <int >(lod.size ()) - 1 ; ++i ){
133
+ // Calculate the length of each sequence
128
134
int64_t seq_len = static_cast <int64_t >(lod[i + 1 ] - lod[i]);
129
- std::memcpy (out_data,in_data,word_len*sizeof (int ));
135
+ // Copy the first words of sequence to output
136
+ std::memcpy (out_data,in_data,word_len*sizeof (T));
137
+ // Point to the next sequence
130
138
in_data += seq_len * word_len;
131
139
out_data += word_len;
132
140
0 commit comments