Skip to content

Commit 0385b0a

Browse files
committed
Accelerate SequencePool Op on SUM mode
test=develop
1 parent e1904ac commit 0385b0a

File tree

2 files changed

+20
-5
lines changed

2 files changed

+20
-5
lines changed

paddle/fluid/operators/math/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ add_subdirectory(detail)
33
endif(NOT WIN32)
44

55
function(math_library TARGET)
6-
# math_library is a function to create math library.
7-
# The interface is the same as cc_library.
6+
# math_library is a function to create math library.
7+
# The interface is the same as cc_library.
88
# But it handle split GPU/CPU code and link some common library.
99
set(cc_srcs)
1010
set(cu_srcs)

paddle/fluid/operators/math/sequence_pooling.cc

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#include "paddle/fluid/operators/math/sequence_pooling.h"
1615
#include <string>
16+
17+
#include "paddle/fluid/operators/math/blas.h"
1718
#include "paddle/fluid/operators/math/math_function.h"
19+
#include "paddle/fluid/operators/math/sequence_pooling.h"
1820

1921
namespace paddle {
2022
namespace operators {
@@ -180,6 +182,7 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
180182
}
181183
auto lod = input.lod()[0];
182184
auto& place = *context.eigen_device();
185+
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
183186
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
184187
Tensor in_t =
185188
input.Slice(static_cast<int>(lod[i]), static_cast<int>(lod[i + 1]));
@@ -191,7 +194,14 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
191194
if (pooltype == "AVERAGE") {
192195
out_e.device(place) = in_e.mean(Eigen::array<int, 1>({{0}}));
193196
} else if (pooltype == "SUM") {
194-
out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}}));
197+
if (h > 0) {
198+
const T* in_data = in_t.data<T>();
199+
T* out_data = out_t.mutable_data<T>(context.GetPlace());
200+
blas.VCOPY(w, in_data, out_data);
201+
for (int64_t r = 1; r != h; ++r) {
202+
blas.AXPY(w, 1., in_data + r * w, out_data);
203+
}
204+
}
195205
} else if (pooltype == "SQRT") {
196206
out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}})) /
197207
std::sqrt(static_cast<T>(h));
@@ -223,6 +233,7 @@ class SequencePoolGradFunctor<platform::CPUDeviceContext, T> {
223233
}
224234
auto lod = in_grad->lod()[0];
225235
auto& place = *context.eigen_device();
236+
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
226237
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
227238
auto in_g_t = in_grad->Slice(static_cast<int>(lod[i]),
228239
static_cast<int>(lod[i + 1]));
@@ -237,7 +248,11 @@ class SequencePoolGradFunctor<platform::CPUDeviceContext, T> {
237248
if (pooltype == "AVERAGE") {
238249
in_g_e.device(place) = (out_g_e / static_cast<T>(h)).broadcast(bcast);
239250
} else if (pooltype == "SUM") {
240-
in_g_e.device(place) = (out_g_e).broadcast(bcast);
251+
const T* out_g_data = out_g_t.data<T>();
252+
T* in_g_data = in_g_t.mutable_data<T>(context.GetPlace());
253+
for (int r = 0; r != h; ++r) {
254+
blas.VCOPY(w, out_g_data, in_g_data + r * w);
255+
}
241256
} else if (pooltype == "SQRT") {
242257
in_g_e.device(place) =
243258
(out_g_e / std::sqrt(static_cast<T>(h))).broadcast(bcast);

0 commit comments

Comments
 (0)