Skip to content

Commit 2468057

Browse files
committed
Move code to SumSeqPoolGradFunctor
test=develop
1 parent 9725db0 commit 2468057

File tree

1 file changed

+27
-17
lines changed

1 file changed

+27
-17
lines changed

paddle/fluid/operators/math/sequence_pooling.cc

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,31 @@ class FirstSeqPoolFunctor {
157157
}
158158
};
159159

160+
template <typename T>
161+
class SumSeqPoolGradFunctor {
162+
public:
163+
void operator()(const platform::CPUDeviceContext& context,
164+
const framework::Tensor& out_grad,
165+
framework::LoDTensor* in_grad) {
166+
auto lod = in_grad->lod()[0];
167+
int64_t out_w = out_grad.numel() / out_grad.dims()[0];
168+
int64_t in_w = in_grad->numel() / in_grad->dims()[0];
169+
PADDLE_ENFORCE(in_w == out_w);
170+
const T* out_g_data = out_grad.data<T>();
171+
T* in_g_data = in_grad->mutable_data<T>(context.GetPlace());
172+
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
173+
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
174+
int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
175+
int64_t in_offset = lod[i] * in_w;
176+
const T* out_pos = out_g_data + i * out_w;
177+
T* in_pos = in_g_data + in_offset;
178+
for (int r = 0; r != h; ++r) {
179+
blas.VCOPY(in_w, out_pos, in_pos + r * in_w);
180+
}
181+
}
182+
}
183+
};
184+
160185
template <typename T>
161186
class SequencePoolFunctor<platform::CPUDeviceContext, T> {
162187
public:
@@ -233,23 +258,8 @@ class SequencePoolGradFunctor<platform::CPUDeviceContext, T> {
233258
}
234259

235260
if (pooltype == "SUM") {
236-
auto lod = in_grad->lod()[0];
237-
int64_t out_w = out_grad.numel() / out_grad.dims()[0];
238-
int64_t in_w = in_grad->numel() / in_grad->dims()[0];
239-
PADDLE_ENFORCE(in_w == out_w);
240-
const T* out_g_data = out_grad.data<T>();
241-
T* in_g_data = in_grad->mutable_data<T>(context.GetPlace());
242-
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
243-
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
244-
int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
245-
int64_t in_offset = lod[i] * in_w;
246-
const T* out_pos = out_g_data + i * out_w;
247-
T* in_pos = in_g_data + in_offset;
248-
for (int r = 0; r != h; ++r) {
249-
blas.VCOPY(in_w, out_pos, in_pos + r * in_w);
250-
}
251-
}
252-
261+
math::SumSeqPoolGradFunctor<T> sum_pool_grad;
262+
sum_pool_grad(context, out_grad, in_grad);
253263
return;
254264
}
255265

0 commit comments

Comments
 (0)