@@ -14,6 +14,7 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/lrn_op.h"
 #include <string>
+#include "paddle/fluid/operators/math/blas.h"
 #ifdef PADDLE_WITH_MKLDNN
 #include "paddle/fluid/platform/mkldnn_helper.h"
 #endif
@@ -29,34 +30,43 @@ struct LRNFunctor<platform::CPUDeviceContext, T> {
                   const framework::Tensor& input, framework::Tensor* out,
                   framework::Tensor* mid, int N, int C, int H, int W, int n,
                   T k, T alpha, T beta) {
-    auto x_v = framework::EigenVector<T>::Flatten(input);
-
-    const int start = -(n - 1) / 2;
-    const int end = start + n;
-
-    auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
-    e_mid = e_mid.constant(k);
-
-    auto e_x = framework::EigenTensor<T, 4>::From(input);
-    for (int m = 0; m < N; m++) {
-      for (int i = 0; i < C; i++) {
-        for (int c = start; c < end; c++) {
-          int ch = i + c;
-          if (ch >= 0 && ch < C) {
-            auto s = e_mid.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
-                                 Eigen::array<int, 4>({{1, 1, H, W}}));
-
-            auto r = e_x.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
-                               Eigen::array<int, 4>({{1, 1, H, W}}));
-
-            s += alpha * r.square();
-          }
-        }
+    const T* idata = input.data<T>();
+    auto place = ctx.GetPlace();
+    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);
+    T* odata = out->mutable_data<T>(place);
+    T* mdata = mid->mutable_data<T>(place);
+    Tensor squared;
+    T* sdata = squared.mutable_data<T>({1, C + n - 1, H, W}, place);
+    std::memset(sdata, 0, sizeof(T) * squared.numel());
+    for (int i = 0; i < mid->numel(); ++i) {
+      mdata[i] = k;
+    }
+    int img_size = H * W;
+    int fea_size = C * img_size;
+    int pre_pad = (n - 1) / 2;
+    // compute batches one by one
+    for (int i = 0; i < N; ++i) {
+      blas.VSQR(fea_size, idata + i * fea_size, sdata + pre_pad * img_size);
+      // init the first channel of mid
+      for (int c = 0; c < n; ++c) {
+        blas.AXPY(img_size, alpha, sdata + c * img_size, mdata + i * fea_size);
+      }
+      for (int c = 1; c < C; ++c) {
+        // copy previous scale
+        int mid_offset = i * fea_size + c * img_size;
+        std::memcpy(mdata + mid_offset, mdata + mid_offset - img_size,
+                    img_size * sizeof(T));
+        // add last
+        blas.AXPY(img_size, alpha, sdata + (c + n - 1) * img_size,
+                  mdata + mid_offset);
+        // sub rest
+        blas.AXPY(img_size, -alpha, sdata + (c - 1) * img_size,
+                  mdata + mid_offset);
       }
     }
-
-    auto out_e = framework::EigenVector<T>::Flatten(*out);
-    out_e = x_v * e_mid.reshape(Eigen::DSizes<int, 1>(e_mid.size())).pow(-beta);
+    // compute the final output
+    blas.VPOW(mid->numel(), mdata, -beta, odata);
+    blas.VMUL(mid->numel(), odata, idata, odata);
   }
 };
 template struct LRNFunctor<platform::CPUDeviceContext, float>;
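
For intuition, here is a minimal standalone sketch of the sliding-window update the new functor expresses with blas calls: instead of re-summing n channels of squares for every output channel, it copies the previous channel's accumulated scale, adds the square map entering the window, and subtracts the one leaving it. Plain loops stand in for blas.VSQR/AXPY, and all sizes and values below are made up for illustration; this is not the operator itself.

#include <cstring>
#include <vector>

int main() {
  // Hypothetical sizes: one batch item with C channels of H x W feature maps.
  const int C = 5, H = 2, W = 2, n = 3;
  const float alpha = 1e-4f, k = 2.0f;
  const int img_size = H * W;
  const int pre_pad = (n - 1) / 2;

  std::vector<float> x(C * img_size, 1.0f);              // input
  std::vector<float> sqr((C + n - 1) * img_size, 0.0f);  // zero-padded squares
  std::vector<float> mid(C * img_size, k);               // scale, initialized to k

  // Square the input, shifted by pre_pad channels (blas.VSQR in the diff).
  for (int i = 0; i < C * img_size; ++i) {
    sqr[pre_pad * img_size + i] = x[i] * x[i];
  }
  // Channel 0: sum the first n padded channels (the first AXPY loop above).
  for (int c = 0; c < n; ++c) {
    for (int i = 0; i < img_size; ++i) {
      mid[i] += alpha * sqr[c * img_size + i];
    }
  }
  // Channels 1..C-1: slide the window in O(img_size) work per channel.
  for (int c = 1; c < C; ++c) {
    std::memcpy(&mid[c * img_size], &mid[(c - 1) * img_size],
                img_size * sizeof(float));
    for (int i = 0; i < img_size; ++i) {
      mid[c * img_size + i] += alpha * sqr[(c + n - 1) * img_size + i];  // enters
      mid[c * img_size + i] -= alpha * sqr[(c - 1) * img_size + i];      // leaves
    }
  }
  return 0;
}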
@@ -156,6 +166,9 @@ class LRNOp : public framework::OperatorWithKernel {
     auto x_dim = ctx->GetInputDim("X");
     PADDLE_ENFORCE_EQ(x_dim.size(), 4, "Input(X)'rank of LRNOp should be 4.");
 
+    int n = ctx->Attrs().Get<int>("n");
+    PADDLE_ENFORCE(n > 0 && n % 2 == 1, "n should be positive odd value");
+
     ctx->SetOutputDim("Out", x_dim);
     ctx->ShareLoD("X", /* ->*/ "Out");
     ctx->SetOutputDim("MidOut", x_dim);