@@ -103,6 +103,39 @@ class MaxSeqPoolGradFunctor {
103
103
}
104
104
};
105
105
106
+ template <typename T>
107
+ class LastFirstSeqPoolFunctor {
108
+ public:
109
+ void operator ()(const platform::CPUDeviceContext& context,
110
+ const framework::LoDTensor& input, framework::Tensor* output,
111
+ const std::string pooltype) {
112
+ auto * in_data = input.data <T>();
113
+ auto * out_data = output->data <T>();
114
+ int64_t word_len = input.numel () / input.dims ()[0 ];
115
+ auto lod = input.lod ()[0 ];
116
+ auto dims = input.dims ();
117
+ if (pooltype == " LAST" ){
118
+ for (int i=0 ; i < static_cast <int >(lod.size ()) - 1 ; ++i ){
119
+ int64_t seq_len = static_cast <int64_t >(lod[i + 1 ] - lod[i]);
120
+ in_data += seq_len* word_len;
121
+ std::memcpy (out_data,(in_data-word_len),word_len*sizeof (int ));
122
+ out_data += word_len;
123
+
124
+ }
125
+ }
126
+ else if (pooltype == " FIRST" ){
127
+ for (int i=0 ; i < static_cast <int >(lod.size ()) - 1 ; ++i ){
128
+ int64_t seq_len = static_cast <int64_t >(lod[i + 1 ] - lod[i]);
129
+ std::memcpy (out_data,in_data,word_len*sizeof (int ));
130
+ in_data += seq_len * word_len;
131
+ out_data += word_len;
132
+
133
+ }
134
+
135
+ }
136
+ }
137
+ };
138
+
106
139
template <typename T>
107
140
class SequencePoolFunctor <platform::CPUDeviceContext, T> {
108
141
public:
@@ -116,6 +149,12 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
116
149
max_pool (context, input, output, index);
117
150
return ;
118
151
}
152
+ if (pooltype == " LAST" || pooltype == " FIRST" ) {
153
+ math::LastFirstSeqPoolFunctor<T> lastfirst_pool;
154
+ lastfirst_pool (context, input, output, pooltype);
155
+ return ;
156
+ }
157
+
119
158
auto lod = input.lod ()[0 ];
120
159
auto & place = *context.eigen_device ();
121
160
for (int i = 0 ; i < static_cast <int >(lod.size ()) - 1 ; ++i) {
@@ -133,10 +172,6 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
133
172
} else if (pooltype == " SQRT" ) {
134
173
out_e.device (place) = in_e.sum (Eigen::array<int , 1 >({{0 }})) /
135
174
std::sqrt (static_cast <T>(h));
136
- } else if (pooltype == " LAST" ) {
137
- out_e.device (place) = in_e.chip (h - 1 , 0 );
138
- } else if (pooltype == " FIRST" ) {
139
- out_e.device (place) = in_e.chip (0 , 0 );
140
175
} else {
141
176
PADDLE_THROW (" unsupported pooling pooltype" );
142
177
}
0 commit comments