Skip to content

Commit 53185fd

Browse files
committed
Rewrite sequence pooling last and first mode with memcpy and clean code
1 parent a557608 commit 53185fd

File tree

1 file changed

+39
-4
lines changed

1 file changed

+39
-4
lines changed

paddle/fluid/operators/math/sequence_pooling.cc

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,39 @@ class MaxSeqPoolGradFunctor {
103103
}
104104
};
105105

106+
template <typename T>
107+
class LastFirstSeqPoolFunctor {
108+
public:
109+
void operator()(const platform::CPUDeviceContext& context,
110+
const framework::LoDTensor& input, framework::Tensor* output,
111+
const std::string pooltype) {
112+
auto* in_data = input.data<T>();
113+
auto* out_data = output->data<T>();
114+
int64_t word_len = input.numel() / input.dims()[0];
115+
auto lod = input.lod()[0];
116+
auto dims = input.dims();
117+
if (pooltype == "LAST"){
118+
for (int i=0; i < static_cast<int>(lod.size()) - 1; ++i ){
119+
int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]);
120+
in_data += seq_len* word_len;
121+
std::memcpy(out_data,(in_data-word_len),word_len*sizeof(int));
122+
out_data += word_len;
123+
124+
}
125+
}
126+
else if(pooltype == "FIRST"){
127+
for (int i=0; i < static_cast<int>(lod.size()) - 1; ++i ){
128+
int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]);
129+
std::memcpy(out_data,in_data,word_len*sizeof(int));
130+
in_data += seq_len * word_len;
131+
out_data += word_len;
132+
133+
}
134+
135+
}
136+
}
137+
};
138+
106139
template <typename T>
107140
class SequencePoolFunctor<platform::CPUDeviceContext, T> {
108141
public:
@@ -116,6 +149,12 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
116149
max_pool(context, input, output, index);
117150
return;
118151
}
152+
if (pooltype == "LAST" || pooltype == "FIRST") {
153+
math::LastFirstSeqPoolFunctor<T> lastfirst_pool;
154+
lastfirst_pool(context, input, output, pooltype);
155+
return;
156+
}
157+
119158
auto lod = input.lod()[0];
120159
auto& place = *context.eigen_device();
121160
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
@@ -133,10 +172,6 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
133172
} else if (pooltype == "SQRT") {
134173
out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}})) /
135174
std::sqrt(static_cast<T>(h));
136-
} else if (pooltype == "LAST") {
137-
out_e.device(place) = in_e.chip(h - 1, 0);
138-
} else if (pooltype == "FIRST") {
139-
out_e.device(place) = in_e.chip(0, 0);
140175
} else {
141176
PADDLE_THROW("unsupported pooling pooltype");
142177
}

0 commit comments

Comments
 (0)