@@ -157,6 +157,31 @@ class FirstSeqPoolFunctor {
157
157
}
158
158
};
159
159
160
+ template <typename T>
161
+ class SumSeqPoolGradFunctor {
162
+ public:
163
+ void operator ()(const platform::CPUDeviceContext& context,
164
+ const framework::Tensor& out_grad,
165
+ framework::LoDTensor* in_grad) {
166
+ auto lod = in_grad->lod ()[0 ];
167
+ int64_t out_w = out_grad.numel () / out_grad.dims ()[0 ];
168
+ int64_t in_w = in_grad->numel () / in_grad->dims ()[0 ];
169
+ PADDLE_ENFORCE (in_w == out_w);
170
+ const T* out_g_data = out_grad.data <T>();
171
+ T* in_g_data = in_grad->mutable_data <T>(context.GetPlace ());
172
+ auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
173
+ for (int i = 0 ; i < static_cast <int >(lod.size ()) - 1 ; ++i) {
174
+ int64_t h = static_cast <int64_t >(lod[i + 1 ] - lod[i]);
175
+ int64_t in_offset = lod[i] * in_w;
176
+ const T* out_pos = out_g_data + i * out_w;
177
+ T* in_pos = in_g_data + in_offset;
178
+ for (int r = 0 ; r != h; ++r) {
179
+ blas.VCOPY (in_w, out_pos, in_pos + r * in_w);
180
+ }
181
+ }
182
+ }
183
+ };
184
+
160
185
template <typename T>
161
186
class SequencePoolFunctor <platform::CPUDeviceContext, T> {
162
187
public:
@@ -233,23 +258,8 @@ class SequencePoolGradFunctor<platform::CPUDeviceContext, T> {
233
258
}
234
259
235
260
if (pooltype == " SUM" ) {
236
- auto lod = in_grad->lod ()[0 ];
237
- int64_t out_w = out_grad.numel () / out_grad.dims ()[0 ];
238
- int64_t in_w = in_grad->numel () / in_grad->dims ()[0 ];
239
- PADDLE_ENFORCE (in_w == out_w);
240
- const T* out_g_data = out_grad.data <T>();
241
- T* in_g_data = in_grad->mutable_data <T>(context.GetPlace ());
242
- auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
243
- for (int i = 0 ; i < static_cast <int >(lod.size ()) - 1 ; ++i) {
244
- int64_t h = static_cast <int64_t >(lod[i + 1 ] - lod[i]);
245
- int64_t in_offset = lod[i] * in_w;
246
- const T* out_pos = out_g_data + i * out_w;
247
- T* in_pos = in_g_data + in_offset;
248
- for (int r = 0 ; r != h; ++r) {
249
- blas.VCOPY (in_w, out_pos, in_pos + r * in_w);
250
- }
251
- }
252
-
261
+ math::SumSeqPoolGradFunctor<T> sum_pool_grad;
262
+ sum_pool_grad (context, out_grad, in_grad);
253
263
return ;
254
264
}
255
265
0 commit comments