@@ -104,11 +104,10 @@ class MaxSeqPoolGradFunctor {
104
104
};
105
105
106
106
template <typename T>
107
- class LastFirstSeqPoolFunctor {
107
+ class LastSeqPoolFunctor {
108
108
public:
109
109
void operator ()(const platform::CPUDeviceContext& context,
110
- const framework::LoDTensor& input, framework::Tensor* output,
111
- const std::string pooltype) {
110
+ const framework::LoDTensor& input, framework::Tensor* output) {
112
111
// Create pointers to input and output data
113
112
auto * in_data = input.data <T>();
114
113
auto * out_data = output->data <T>();
@@ -117,29 +116,40 @@ class LastFirstSeqPoolFunctor {
117
116
int64_t item_size = input.numel () / input.dims ()[0 ];
118
117
auto lod = input.lod ()[0 ];
119
118
int seq_num = static_cast <int >(lod.size ()) - 1 ;
120
- if (pooltype == " LAST" ) {
121
- for (int i = 0 ; i < seq_num; ++i) {
122
- // Calculate the length of each sequence
123
- int64_t seq_len = static_cast <int64_t >(lod[i + 1 ] - lod[i]);
124
- // Point to the begin of next sequence
125
- in_data += seq_len * item_size;
126
- // Copy the last item of sequence to output
127
- std::memcpy (out_data, (in_data - item_size), item_size * sizeof (T));
128
- out_data += item_size;
129
- }
130
- } else if (pooltype == " FIRST" ) {
131
- for (int i = 0 ; i < seq_num; ++i) {
132
- // Calculate the length of each sequence
133
- int64_t seq_len = static_cast <int64_t >(lod[i + 1 ] - lod[i]);
134
- // Copy the first item of sequence to output
135
- std::memcpy (out_data, in_data, item_size * sizeof (T));
136
- // Point to the next sequence
137
- in_data += seq_len * item_size;
138
- out_data += item_size;
119
+ for (int i = 0 ; i < seq_num; ++i) {
120
+ // Calculate the length of each sequence
121
+ int64_t seq_len = static_cast <int64_t >(lod[i + 1 ] - lod[i]);
122
+ // Point to the begin of next sequence
123
+ in_data += seq_len * item_size;
124
+ // Copy the last item of sequence to output
125
+ std::memcpy (out_data, (in_data - item_size), item_size * sizeof (T));
126
+ out_data += item_size;
127
+ }
128
+ }
129
+ };
130
+
131
+ template <typename T>
132
+ class FirstSeqPoolFunctor {
133
+ public:
134
+ void operator ()(const platform::CPUDeviceContext& context,
135
+ const framework::LoDTensor& input, framework::Tensor* output) {
136
+ // Create pointers to input and output data
137
+ auto * in_data = input.data <T>();
138
+ auto * out_data = output->data <T>();
139
+
140
+ // Calculate the size of each item in sequence
141
+ int64_t item_size = input.numel () / input.dims ()[0 ];
142
+ auto lod = input.lod ()[0 ];
143
+ int seq_num = static_cast <int >(lod.size ()) - 1 ;
144
+ for (int i = 0 ; i < seq_num; ++i) {
145
+ // Calculate the length of each sequence
146
+ int64_t seq_len = static_cast <int64_t >(lod[i + 1 ] - lod[i]);
147
+ // Copy the first item of sequence to output
148
+ std::memcpy (out_data, in_data, item_size * sizeof (T));
149
+ // Point to the next sequence
150
+ in_data += seq_len * item_size;
151
+ out_data += item_size;
139
152
}
140
- } else {
141
- PADDLE_THROW (" it's not LAST or FIRST pool type" );
142
- }
143
153
}
144
154
};
145
155
@@ -156,11 +166,17 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
156
166
max_pool (context, input, output, index);
157
167
return ;
158
168
}
159
- if (pooltype == " LAST" || pooltype == " FIRST " ) {
160
- math::LastFirstSeqPoolFunctor <T> lastfirst_pool ;
161
- lastfirst_pool (context, input, output, pooltype );
169
+ if (pooltype == " LAST" ) {
170
+ math::LastSeqPoolFunctor <T> last_pool ;
171
+ last_pool (context, input, output);
162
172
return ;
163
173
}
174
+ if (pooltype == " FIRST" ) {
175
+ math::FirstSeqPoolFunctor<T> first_pool;
176
+ first_pool (context, input, output);
177
+ return ;
178
+ }
179
+
164
180
165
181
auto lod = input.lod ()[0 ];
166
182
auto & place = *context.eigen_device ();
0 commit comments