Skip to content

Commit b4dfba1

Browse files
committed
refine lrn_op cpu forward and speedup
test=develop
1 parent 1be85d0 commit b4dfba1

File tree

2 files changed

+39
-27
lines changed

2 files changed

+39
-27
lines changed

paddle/fluid/operators/lrn_op.cc

Lines changed: 39 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ limitations under the License. */
1414

1515
#include "paddle/fluid/operators/lrn_op.h"
1616
#include <string>
17+
#include "paddle/fluid/operators/math/blas.h"
1718
#ifdef PADDLE_WITH_MKLDNN
1819
#include "paddle/fluid/platform/mkldnn_helper.h"
1920
#endif
@@ -29,34 +30,43 @@ struct LRNFunctor<platform::CPUDeviceContext, T> {
2930
const framework::Tensor& input, framework::Tensor* out,
3031
framework::Tensor* mid, int N, int C, int H, int W, int n,
3132
T k, T alpha, T beta) {
32-
auto x_v = framework::EigenVector<T>::Flatten(input);
33-
34-
const int start = -(n - 1) / 2;
35-
const int end = start + n;
36-
37-
auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
38-
e_mid = e_mid.constant(k);
39-
40-
auto e_x = framework::EigenTensor<T, 4>::From(input);
41-
for (int m = 0; m < N; m++) {
42-
for (int i = 0; i < C; i++) {
43-
for (int c = start; c < end; c++) {
44-
int ch = i + c;
45-
if (ch >= 0 && ch < C) {
46-
auto s = e_mid.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
47-
Eigen::array<int, 4>({{1, 1, H, W}}));
48-
49-
auto r = e_x.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
50-
Eigen::array<int, 4>({{1, 1, H, W}}));
51-
52-
s += alpha * r.square();
53-
}
54-
}
33+
const T* idata = input.data<T>();
34+
auto place = ctx.GetPlace();
35+
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);
36+
T* odata = out->mutable_data<T>(place);
37+
T* mdata = mid->mutable_data<T>(place);
38+
Tensor squared;
39+
T* sdata = squared.mutable_data<T>({1, C + n - 1, H, W}, place);
40+
std::memset(sdata, 0, sizeof(T) * squared.numel());
41+
for (int i = 0; i < mid->numel(); ++i) {
42+
mdata[i] = k;
43+
}
44+
int img_size = H * W;
45+
int fea_size = C * img_size;
46+
int pre_pad = (n - 1) / 2;
47+
// compute batches one by one
48+
for (int i = 0; i < N; ++i) {
49+
blas.VSQR(fea_size, idata + i * fea_size, sdata + pre_pad * img_size);
50+
// init the first channel of mid
51+
for (int c = 0; c < n; ++c) {
52+
blas.AXPY(img_size, alpha, sdata + c * img_size, mdata + i * fea_size);
53+
}
54+
for (int c = 1; c < C; ++c) {
55+
// copy previous scale
56+
int mid_offset = i * fea_size + c * img_size;
57+
std::memcpy(mdata + mid_offset, mdata + mid_offset - img_size,
58+
img_size * sizeof(T));
59+
// add last
60+
blas.AXPY(img_size, alpha, sdata + (c + n - 1) * img_size,
61+
mdata + mid_offset);
62+
// sub rest
63+
blas.AXPY(img_size, -alpha, sdata + (c - 1) * img_size,
64+
mdata + mid_offset);
5565
}
5666
}
57-
58-
auto out_e = framework::EigenVector<T>::Flatten(*out);
59-
out_e = x_v * e_mid.reshape(Eigen::DSizes<int, 1>(e_mid.size())).pow(-beta);
67+
// compute the final output
68+
blas.VPOW(mid->numel(), mdata, -beta, odata);
69+
blas.VMUL(mid->numel(), odata, idata, odata);
6070
}
6171
};
6272
template struct LRNFunctor<platform::CPUDeviceContext, float>;
@@ -156,6 +166,9 @@ class LRNOp : public framework::OperatorWithKernel {
156166
auto x_dim = ctx->GetInputDim("X");
157167
PADDLE_ENFORCE_EQ(x_dim.size(), 4, "Input(X)'rank of LRNOp should be 4.");
158168

169+
int n = ctx->Attrs().Get<int>("n");
170+
PADDLE_ENFORCE(n > 0 && n % 2 == 1, "n should be positive odd value");
171+
159172
ctx->SetOutputDim("Out", x_dim);
160173
ctx->ShareLoD("X", /*->*/ "Out");
161174
ctx->SetOutputDim("MidOut", x_dim);

paddle/fluid/operators/lrn_op.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ class LRNKernel : public framework::OpKernel<T> {
6060
T beta = ctx.Attr<float>("beta");
6161
T k = ctx.Attr<float>("k");
6262

63-
PADDLE_ENFORCE(n > 0, "n should >= 0");
6463
PADDLE_ENFORCE(alpha >= 0.0, "alpha should >= 0.0");
6564
PADDLE_ENFORCE(beta >= 0.0, "beta should >= 0.0");
6665
PADDLE_ENFORCE(k >= 0.0, "k should >= 0.0");

0 commit comments

Comments
 (0)