Skip to content

Commit cab2982

Browse files
authored
Merge pull request #13829 from velconia/accelerate_sequence_pool_op
Accelerate SequencePool Op on SUM mode of CPU
2 parents c284237 + 0385b0a commit cab2982

File tree

1 file changed

+18
-3
lines changed

1 file changed

+18
-3
lines changed

paddle/fluid/operators/math/sequence_pooling.cc

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#include "paddle/fluid/operators/math/sequence_pooling.h"
1615
#include <string>
16+
17+
#include "paddle/fluid/operators/math/blas.h"
1718
#include "paddle/fluid/operators/math/math_function.h"
19+
#include "paddle/fluid/operators/math/sequence_pooling.h"
1820

1921
namespace paddle {
2022
namespace operators {
@@ -180,6 +182,7 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
180182
}
181183
auto lod = input.lod()[0];
182184
auto& place = *context.eigen_device();
185+
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
183186
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
184187
Tensor in_t =
185188
input.Slice(static_cast<int>(lod[i]), static_cast<int>(lod[i + 1]));
@@ -191,7 +194,14 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
191194
if (pooltype == "AVERAGE") {
192195
out_e.device(place) = in_e.mean(Eigen::array<int, 1>({{0}}));
193196
} else if (pooltype == "SUM") {
194-
out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}}));
197+
if (h > 0) {
198+
const T* in_data = in_t.data<T>();
199+
T* out_data = out_t.mutable_data<T>(context.GetPlace());
200+
blas.VCOPY(w, in_data, out_data);
201+
for (int64_t r = 1; r != h; ++r) {
202+
blas.AXPY(w, 1., in_data + r * w, out_data);
203+
}
204+
}
195205
} else if (pooltype == "SQRT") {
196206
out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}})) /
197207
std::sqrt(static_cast<T>(h));
@@ -223,6 +233,7 @@ class SequencePoolGradFunctor<platform::CPUDeviceContext, T> {
223233
}
224234
auto lod = in_grad->lod()[0];
225235
auto& place = *context.eigen_device();
236+
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
226237
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
227238
auto in_g_t = in_grad->Slice(static_cast<int>(lod[i]),
228239
static_cast<int>(lod[i + 1]));
@@ -237,7 +248,11 @@ class SequencePoolGradFunctor<platform::CPUDeviceContext, T> {
237248
if (pooltype == "AVERAGE") {
238249
in_g_e.device(place) = (out_g_e / static_cast<T>(h)).broadcast(bcast);
239250
} else if (pooltype == "SUM") {
240-
in_g_e.device(place) = (out_g_e).broadcast(bcast);
251+
const T* out_g_data = out_g_t.data<T>();
252+
T* in_g_data = in_g_t.mutable_data<T>(context.GetPlace());
253+
for (int r = 0; r != h; ++r) {
254+
blas.VCOPY(w, out_g_data, in_g_data + r * w);
255+
}
241256
} else if (pooltype == "SQRT") {
242257
in_g_e.device(place) =
243258
(out_g_e / std::sqrt(static_cast<T>(h))).broadcast(bcast);

0 commit comments

Comments
 (0)