Commit 9eb0ab1

Merge pull request #14384 from tensor-tang/refine/lrn
Refine lrn cpu forward
2 parents: e65cbd3 + b4dfba1

5 files changed: 107 additions and 27 deletions

paddle/fluid/operators/lrn_op.cc

Lines changed: 39 additions & 26 deletions
@@ -14,6 +14,7 @@ limitations under the License. */

 #include "paddle/fluid/operators/lrn_op.h"
 #include <string>
+#include "paddle/fluid/operators/math/blas.h"
 #ifdef PADDLE_WITH_MKLDNN
 #include "paddle/fluid/platform/mkldnn_helper.h"
 #endif
@@ -29,34 +30,43 @@ struct LRNFunctor<platform::CPUDeviceContext, T> {
                   const framework::Tensor& input, framework::Tensor* out,
                   framework::Tensor* mid, int N, int C, int H, int W, int n,
                   T k, T alpha, T beta) {
-    auto x_v = framework::EigenVector<T>::Flatten(input);
-
-    const int start = -(n - 1) / 2;
-    const int end = start + n;
-
-    auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
-    e_mid = e_mid.constant(k);
-
-    auto e_x = framework::EigenTensor<T, 4>::From(input);
-    for (int m = 0; m < N; m++) {
-      for (int i = 0; i < C; i++) {
-        for (int c = start; c < end; c++) {
-          int ch = i + c;
-          if (ch >= 0 && ch < C) {
-            auto s = e_mid.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
-                                 Eigen::array<int, 4>({{1, 1, H, W}}));
-
-            auto r = e_x.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
-                               Eigen::array<int, 4>({{1, 1, H, W}}));
-
-            s += alpha * r.square();
-          }
-        }
+    const T* idata = input.data<T>();
+    auto place = ctx.GetPlace();
+    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);
+    T* odata = out->mutable_data<T>(place);
+    T* mdata = mid->mutable_data<T>(place);
+    Tensor squared;
+    T* sdata = squared.mutable_data<T>({1, C + n - 1, H, W}, place);
+    std::memset(sdata, 0, sizeof(T) * squared.numel());
+    for (int i = 0; i < mid->numel(); ++i) {
+      mdata[i] = k;
+    }
+    int img_size = H * W;
+    int fea_size = C * img_size;
+    int pre_pad = (n - 1) / 2;
+    // compute batches one by one
+    for (int i = 0; i < N; ++i) {
+      blas.VSQR(fea_size, idata + i * fea_size, sdata + pre_pad * img_size);
+      // init the first channel of mid
+      for (int c = 0; c < n; ++c) {
+        blas.AXPY(img_size, alpha, sdata + c * img_size, mdata + i * fea_size);
+      }
+      for (int c = 1; c < C; ++c) {
+        // copy previous scale
+        int mid_offset = i * fea_size + c * img_size;
+        std::memcpy(mdata + mid_offset, mdata + mid_offset - img_size,
+                    img_size * sizeof(T));
+        // add last
+        blas.AXPY(img_size, alpha, sdata + (c + n - 1) * img_size,
+                  mdata + mid_offset);
+        // sub rest
+        blas.AXPY(img_size, -alpha, sdata + (c - 1) * img_size,
+                  mdata + mid_offset);
       }
     }
-
-    auto out_e = framework::EigenVector<T>::Flatten(*out);
-    out_e = x_v * e_mid.reshape(Eigen::DSizes<int, 1>(e_mid.size())).pow(-beta);
+    // compute the final output
+    blas.VPOW(mid->numel(), mdata, -beta, odata);
+    blas.VMUL(mid->numel(), odata, idata, odata);
   }
 };
 template struct LRNFunctor<platform::CPUDeviceContext, float>;
@@ -156,6 +166,9 @@ class LRNOp : public framework::OperatorWithKernel {
     auto x_dim = ctx->GetInputDim("X");
     PADDLE_ENFORCE_EQ(x_dim.size(), 4, "Input(X)'rank of LRNOp should be 4.");

+    int n = ctx->Attrs().Get<int>("n");
+    PADDLE_ENFORCE(n > 0 && n % 2 == 1, "n should be positive odd value");
+
     ctx->SetOutputDim("Out", x_dim);
     ctx->ShareLoD("X", /*->*/ "Out");
     ctx->SetOutputDim("MidOut", x_dim);

paddle/fluid/operators/lrn_op.h

Lines changed: 0 additions & 1 deletion
@@ -60,7 +60,6 @@ class LRNKernel : public framework::OpKernel<T> {
     T beta = ctx.Attr<float>("beta");
     T k = ctx.Attr<float>("k");

-    PADDLE_ENFORCE(n > 0, "n should >= 0");
     PADDLE_ENFORCE(alpha >= 0.0, "alpha should >= 0.0");
     PADDLE_ENFORCE(beta >= 0.0, "beta should >= 0.0");
     PADDLE_ENFORCE(k >= 0.0, "k should >= 0.0");

paddle/fluid/operators/math/blas.h

Lines changed: 16 additions & 0 deletions
@@ -152,6 +152,12 @@ class Blas {
   template <typename T>
   void VEXP(int n, const T* x, T* y) const;

+  template <typename T>
+  void VSQR(int n, const T* x, T* y) const;
+
+  template <typename T>
+  void VPOW(int n, const T* x, T alpha, T* y) const;
+
   template <typename T>
   void GEMV(bool trans_a, int M, int N, T alpha, const T* A, const T* B, T beta,
             T* C) const;
@@ -238,6 +244,16 @@ class BlasT : private Blas<DeviceContext> {
     Base()->template VEXP<T>(args...);
   }

+  template <typename... ARGS>
+  void VSQR(ARGS... args) const {
+    Base()->template VSQR<T>(args...);
+  }
+
+  template <typename... ARGS>
+  void VPOW(ARGS... args) const {
+    Base()->template VPOW<T>(args...);
+  }
+
   template <typename... ARGS>
   void GEMV(ARGS... args) const {
     Base()->template GEMV<T>(args...);
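
With these declarations, the new routines are invoked like the existing VMUL/VEXP wrappers. A hedged usage sketch (assuming a CPU device context dev_ctx and float buffers x, y, z of length len, as in the LRN functor above):

  auto blas = math::GetBlas<platform::CPUDeviceContext, float>(dev_ctx);
  blas.VSQR(len, x, y);          // y[i] = x[i] * x[i]
  blas.VPOW(len, y, -0.75f, z);  // z[i] = pow(y[i], -0.75f)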

paddle/fluid/operators/math/blas_impl.h

Lines changed: 48 additions & 0 deletions
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 #pragma once
+#include <cmath>
 #include <limits>
 #include <vector>
 #include "paddle/fluid/operators/math/math_function.h"
@@ -102,6 +103,16 @@ struct CBlas<float> {
   static void VEXP(ARGS... args) {
     platform::dynload::vsExp(args...);
   }
+
+  template <typename... ARGS>
+  static void VSQR(ARGS... args) {
+    platform::dynload::vsSqr(args...);
+  }
+
+  template <typename... ARGS>
+  static void VPOW(ARGS... args) {
+    platform::dynload::vsPowx(args...);
+  }
 };

 template <>
@@ -182,6 +193,16 @@ struct CBlas<double> {
   static void VEXP(ARGS... args) {
     platform::dynload::vdExp(args...);
   }
+
+  template <typename... ARGS>
+  static void VSQR(ARGS... args) {
+    platform::dynload::vdSqr(args...);
+  }
+
+  template <typename... ARGS>
+  static void VPOW(ARGS... args) {
+    platform::dynload::vdPowx(args...);
+  }
 };

 #else
@@ -241,6 +262,8 @@ struct CBlas<platform::float16> {
   }
   static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); }
   static void VEXP(...) { PADDLE_THROW("float16 VEXP not supported on CPU"); }
+  static void VSQR(...) { PADDLE_THROW("float16 VSQR not supported on CPU"); }
+  static void VPOW(...) { PADDLE_THROW("float16 VPOW not supported on CPU"); }
   static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); };
   static void SCAL(...) { PADDLE_THROW("float16 SCAL not supported on CPU"); };
 #ifdef PADDLE_WITH_MKLML
@@ -398,6 +421,31 @@ void Blas<platform::CPUDeviceContext>::VEXP(int n, const T *x, T *y) const {
 #endif
 }

+template <>
+template <typename T>
+void Blas<platform::CPUDeviceContext>::VSQR(int n, const T *x, T *y) const {
+#ifdef PADDLE_WITH_MKLML
+  CBlas<T>::VSQR(n, x, y);
+#else
+  for (int i = 0; i < n; ++i) {
+    y[i] = x[i] * x[i];  // element-wise square, matching MKL v?Sqr
+  }
+#endif
+}
+
+template <>
+template <typename T>
+void Blas<platform::CPUDeviceContext>::VPOW(int n, const T *x, T a,
+                                            T *y) const {
+#ifdef PADDLE_WITH_MKLML
+  CBlas<T>::VPOW(n, x, a, y);
+#else
+  for (int i = 0; i < n; ++i) {
+    y[i] = std::pow(x[i], a);
+  }
+#endif
+}
+
 template <>
 template <typename T>
 T Blas<platform::CPUDeviceContext>::DOT(int n, const T *x, const T *y) const {
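
MKL's v?Sqr performs element-wise squaring and v?Powx raises each element to a fixed scalar power, so the portable fallbacks above must match those semantics. A self-contained check, plain C++ with no Paddle or MKL dependency:

#include <cmath>
#include <cstdio>

int main() {
  const int n = 4;
  float x[] = {1.f, 2.f, 3.f, 4.f};
  float sq[4], pw[4];
  for (int i = 0; i < n; ++i) sq[i] = x[i] * x[i];             // VSQR semantics
  for (int i = 0; i < n; ++i) pw[i] = std::pow(sq[i], -0.5f);  // VPOW semantics
  for (int i = 0; i < n; ++i)
    std::printf("x=%g  sqr=%g  pow=%g\n", x[i], sq[i], pw[i]);
  return 0;
}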

paddle/fluid/platform/dynload/mklml.h

Lines changed: 4 additions & 0 deletions
@@ -76,6 +76,10 @@ extern void* mklml_dso_handle;
   __macro(vdMul);                \
   __macro(vsExp);                \
   __macro(vdExp);                \
+  __macro(vsSqr);                \
+  __macro(vdSqr);                \
+  __macro(vsPowx);               \
+  __macro(vdPowx);               \
   __macro(MKL_Set_Num_Threads)

 MKLML_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_MKLML_WRAP);
