Skip to content

Commit 9c68709

Browse files
committed
Accelerate sequence_pool functor
1 parent 14ebc42 commit 9c68709

File tree

1 file changed

+22
-7
lines changed

1 file changed

+22
-7
lines changed

paddle/fluid/operators/math/sequence_pooling.cc

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -231,9 +231,30 @@ class SequencePoolGradFunctor<platform::CPUDeviceContext, T> {
231231
math::SetConstant<platform::CPUDeviceContext, T> functor;
232232
functor(context, in_grad, 0);
233233
}
234+
235+
if (pooltype == "SUM") {
236+
auto lod = in_grad->lod()[0];
237+
int64_t out_w = out_grad.numel() / out_grad.dims()[0];
238+
int64_t in_w = in_grad->numel() / in_grad->dims()[0];
239+
PADDLE_ENFORCE(in_w == out_w);
240+
const T* out_g_data = out_grad.data<T>();
241+
T* in_g_data = in_grad->mutable_data<T>(context.GetPlace());
242+
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
243+
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
244+
int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
245+
int64_t in_offset = lod[i];
246+
const T* out_pos = out_g_data + i * out_w;
247+
T* in_pos = in_g_data + in_offset;
248+
for (int r = 0; r != h; ++r) {
249+
blas.VCOPY(in_w, out_pos, in_pos + r * in_w);
250+
}
251+
}
252+
253+
return;
254+
}
255+
234256
auto lod = in_grad->lod()[0];
235257
auto& place = *context.eigen_device();
236-
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
237258
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
238259
auto in_g_t = in_grad->Slice(static_cast<int>(lod[i]),
239260
static_cast<int>(lod[i + 1]));
@@ -247,12 +268,6 @@ class SequencePoolGradFunctor<platform::CPUDeviceContext, T> {
247268

248269
if (pooltype == "AVERAGE") {
249270
in_g_e.device(place) = (out_g_e / static_cast<T>(h)).broadcast(bcast);
250-
} else if (pooltype == "SUM") {
251-
const T* out_g_data = out_g_t.data<T>();
252-
T* in_g_data = in_g_t.mutable_data<T>(context.GetPlace());
253-
for (int r = 0; r != h; ++r) {
254-
blas.VCOPY(w, out_g_data, in_g_data + r * w);
255-
}
256271
} else if (pooltype == "SQRT") {
257272
in_g_e.device(place) =
258273
(out_g_e / std::sqrt(static_cast<T>(h))).broadcast(bcast);

0 commit comments

Comments
 (0)