@@ -231,9 +231,30 @@ class SequencePoolGradFunctor<platform::CPUDeviceContext, T> {
231
231
math::SetConstant<platform::CPUDeviceContext, T> functor;
232
232
functor (context, in_grad, 0 );
233
233
}
234
+
235
+ if (pooltype == " SUM" ) {
236
+ auto lod = in_grad->lod ()[0 ];
237
+ int64_t out_w = out_grad.numel () / out_grad.dims ()[0 ];
238
+ int64_t in_w = in_grad->numel () / in_grad->dims ()[0 ];
239
+ PADDLE_ENFORCE (in_w == out_w);
240
+ const T* out_g_data = out_grad.data <T>();
241
+ T* in_g_data = in_grad->mutable_data <T>(context.GetPlace ());
242
+ auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
243
+ for (int i = 0 ; i < static_cast <int >(lod.size ()) - 1 ; ++i) {
244
+ int64_t h = static_cast <int64_t >(lod[i + 1 ] - lod[i]);
245
+ int64_t in_offset = lod[i] * in_w;  // lod[i] is a row index; scale by the row width to get the element offset into in_g_data (matches the r * in_w row stride used below)
246
+ const T* out_pos = out_g_data + i * out_w;
247
+ T* in_pos = in_g_data + in_offset;
248
+ for (int r = 0 ; r != h; ++r) {
249
+ blas.VCOPY (in_w, out_pos, in_pos + r * in_w);
250
+ }
251
+ }
252
+
253
+ return ;
254
+ }
255
+
234
256
auto lod = in_grad->lod ()[0 ];
235
257
auto & place = *context.eigen_device ();
236
- auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
237
258
for (int i = 0 ; i < static_cast <int >(lod.size ()) - 1 ; ++i) {
238
259
auto in_g_t = in_grad->Slice (static_cast <int >(lod[i]),
239
260
static_cast <int >(lod[i + 1 ]));
@@ -247,12 +268,6 @@ class SequencePoolGradFunctor<platform::CPUDeviceContext, T> {
247
268
248
269
if (pooltype == " AVERAGE" ) {
249
270
in_g_e.device (place) = (out_g_e / static_cast <T>(h)).broadcast (bcast);
250
- } else if (pooltype == " SUM" ) {
251
- const T* out_g_data = out_g_t .data <T>();
252
- T* in_g_data = in_g_t .mutable_data <T>(context.GetPlace ());
253
- for (int r = 0 ; r != h; ++r) {
254
- blas.VCOPY (w, out_g_data, in_g_data + r * w);
255
- }
256
271
} else if (pooltype == " SQRT" ) {
257
272
in_g_e.device (place) =
258
273
(out_g_e / std::sqrt (static_cast <T>(h))).broadcast (bcast);
0 commit comments