@@ -113,34 +113,35 @@ class LastFirstSeqPoolFunctor {
113
113
auto * in_data = input.data <T>();
114
114
auto * out_data = output->data <T>();
115
115
116
- // Calculate length of each word
117
- int64_t word_len = input.numel () / input.dims ()[0 ];
116
+ // Calculate the size of each item in sequence
117
+ int64_t item_size = input.numel () / input.dims ()[0 ];
118
118
auto lod = input.lod ()[0 ];
119
+ int seq_num = static_cast <int >(lod.size ()) - 1 ;
119
120
if (pooltype == " LAST" ){
120
- for (int i=0 ; i < static_cast < int >(lod. size ()) - 1 ; ++i ){
121
+ for (int i=0 ; i < seq_num ; ++i ){
121
122
// Calculate the length of each sequence
122
123
int64_t seq_len = static_cast <int64_t >(lod[i + 1 ] - lod[i]);
123
124
// Point to the begin of next sequence
124
- in_data += seq_len* word_len;
125
- // Copy the last words to output
126
- std::memcpy (out_data,(in_data-word_len),word_len*sizeof (T));
127
- out_data += word_len;
128
-
125
+ in_data += seq_len* item_size;
126
+ // Copy the last item to output
127
+ std::memcpy (out_data,(in_data-item_size),item_size*sizeof (T));
128
+ out_data += item_size;
129
129
}
130
130
}
131
131
else if (pooltype == " FIRST" ){
132
- for (int i=0 ; i < static_cast < int >(lod. size ()) - 1 ; ++i ){
132
+ for (int i=0 ; i < seq_num ; ++i ){
133
133
// Calculate the length of each sequence
134
134
int64_t seq_len = static_cast <int64_t >(lod[i + 1 ] - lod[i]);
135
- // Copy the first words of sequence to output
136
- std::memcpy (out_data,in_data,word_len *sizeof (T));
135
+ // Copy the first item of sequence to output
136
+ std::memcpy (out_data,in_data,item_size *sizeof (T));
137
137
// Point to the next sequence
138
- in_data += seq_len * word_len;
139
- out_data += word_len;
140
-
138
+ in_data += seq_len * item_size;
139
+ out_data += item_size;
141
140
}
142
-
143
141
}
142
+ else {
143
+ PADDLE_THROW (" it's not LAST or FIRST pool type" );
144
+ }
144
145
}
145
146
};
146
147
0 commit comments