@@ -109,40 +109,38 @@ class LastFirstSeqPoolFunctor {
109
109
void operator ()(const platform::CPUDeviceContext& context,
110
110
const framework::LoDTensor& input, framework::Tensor* output,
111
111
const std::string pooltype) {
112
- // Create pointers to input and output data
113
- auto * in_data = input.data <T>();
114
- auto * out_data = output->data <T>();
115
-
116
- // Calculate the size of each item in sequence
117
- int64_t item_size = input.numel () / input.dims ()[0 ];
118
- auto lod = input.lod ()[0 ];
119
- int seq_num = static_cast <int >(lod.size ()) - 1 ;
120
- if (pooltype == " LAST" ){
121
- for (int i=0 ; i < seq_num; ++i ){
122
- // Calculate the length of each sequence
123
- int64_t seq_len = static_cast <int64_t >(lod[i + 1 ] - lod[i]);
124
- // Point to the begin of next sequence
125
- in_data += seq_len* item_size;
126
- // Copy the last item to output
127
- std::memcpy (out_data,(in_data-item_size),item_size*sizeof (T));
128
- out_data += item_size;
129
- }
130
- }
131
- else if (pooltype == " FIRST" ){
132
- for (int i=0 ; i < seq_num; ++i ){
133
- // Calculate the length of each sequence
134
- int64_t seq_len = static_cast <int64_t >(lod[i + 1 ] - lod[i]);
135
- // Copy the first item of sequence to output
136
- std::memcpy (out_data,in_data,item_size*sizeof (T));
137
- // Point to the next sequence
138
- in_data += seq_len * item_size;
139
- out_data += item_size;
140
- }
141
- }
142
- else {
143
- PADDLE_THROW (" it's not LAST or FIRST pool type" );
112
+ // Create pointers to input and output data
113
+ auto * in_data = input.data <T>();
114
+ auto * out_data = output->data <T>();
115
+
116
+ // Calculate the size of each item in sequence
117
+ int64_t item_size = input.numel () / input.dims ()[0 ];
118
+ auto lod = input.lod ()[0 ];
119
+ int seq_num = static_cast <int >(lod.size ()) - 1 ;
120
+ if (pooltype == " LAST" ) {
121
+ for (int i = 0 ; i < seq_num; ++i) {
122
+ // Calculate the length of each sequence
123
+ int64_t seq_len = static_cast <int64_t >(lod[i + 1 ] - lod[i]);
124
+ // Point to the begin of next sequence
125
+ in_data += seq_len * item_size;
126
+ // Copy the last item of sequence to output
127
+ std::memcpy (out_data, (in_data - item_size), item_size * sizeof (T));
128
+ out_data += item_size;
144
129
}
145
- }
130
+ } else if (pooltype == " FIRST" ) {
131
+ for (int i = 0 ; i < seq_num; ++i) {
132
+ // Calculate the length of each sequence
133
+ int64_t seq_len = static_cast <int64_t >(lod[i + 1 ] - lod[i]);
134
+ // Copy the first item of sequence to output
135
+ std::memcpy (out_data, in_data, item_size * sizeof (T));
136
+ // Point to the next sequence
137
+ in_data += seq_len * item_size;
138
+ out_data += item_size;
139
+ }
140
+ } else {
141
+ PADDLE_THROW (" it's not LAST or FIRST pool type" );
142
+ }
143
+ }
146
144
};
147
145
148
146
template <typename T>
0 commit comments