@@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
See the License for the specific language governing permissions and
13
13
limitations under the License. */
14
14
15
- #include " paddle/fluid/operators/math/sequence_pooling.h"
16
15
#include < string>
16
+
17
+ #include " paddle/fluid/operators/math/blas.h"
17
18
#include " paddle/fluid/operators/math/math_function.h"
19
+ #include " paddle/fluid/operators/math/sequence_pooling.h"
18
20
19
21
namespace paddle {
20
22
namespace operators {
@@ -180,6 +182,7 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
180
182
}
181
183
auto lod = input.lod ()[0 ];
182
184
auto & place = *context.eigen_device ();
185
+ auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
183
186
for (int i = 0 ; i < static_cast <int >(lod.size ()) - 1 ; ++i) {
184
187
Tensor in_t =
185
188
input.Slice (static_cast <int >(lod[i]), static_cast <int >(lod[i + 1 ]));
@@ -191,7 +194,14 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
191
194
if (pooltype == " AVERAGE" ) {
192
195
out_e.device (place) = in_e.mean (Eigen::array<int , 1 >({{0 }}));
193
196
} else if (pooltype == " SUM" ) {
194
- out_e.device (place) = in_e.sum (Eigen::array<int , 1 >({{0 }}));
197
+ if (h > 0 ) {
198
+ const T* in_data = in_t .data <T>();
199
+ T* out_data = out_t .mutable_data <T>(context.GetPlace ());
200
+ blas.VCOPY (w, in_data, out_data);
201
+ for (int64_t r = 1 ; r != h; ++r) {
202
+ blas.AXPY (w, 1 ., in_data + r * w, out_data);
203
+ }
204
+ }
195
205
} else if (pooltype == " SQRT" ) {
196
206
out_e.device (place) = in_e.sum (Eigen::array<int , 1 >({{0 }})) /
197
207
std::sqrt (static_cast <T>(h));
@@ -223,6 +233,7 @@ class SequencePoolGradFunctor<platform::CPUDeviceContext, T> {
223
233
}
224
234
auto lod = in_grad->lod ()[0 ];
225
235
auto & place = *context.eigen_device ();
236
+ auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
226
237
for (int i = 0 ; i < static_cast <int >(lod.size ()) - 1 ; ++i) {
227
238
auto in_g_t = in_grad->Slice (static_cast <int >(lod[i]),
228
239
static_cast <int >(lod[i + 1 ]));
@@ -237,7 +248,11 @@ class SequencePoolGradFunctor<platform::CPUDeviceContext, T> {
237
248
if (pooltype == " AVERAGE" ) {
238
249
in_g_e.device (place) = (out_g_e / static_cast <T>(h)).broadcast (bcast);
239
250
} else if (pooltype == " SUM" ) {
240
- in_g_e.device (place) = (out_g_e).broadcast (bcast);
251
+ const T* out_g_data = out_g_t .data <T>();
252
+ T* in_g_data = in_g_t .mutable_data <T>(context.GetPlace ());
253
+ for (int r = 0 ; r != h; ++r) {
254
+ blas.VCOPY (w, out_g_data, in_g_data + r * w);
255
+ }
241
256
} else if (pooltype == " SQRT" ) {
242
257
in_g_e.device (place) =
243
258
(out_g_e / std::sqrt (static_cast <T>(h))).broadcast (bcast);
0 commit comments