Commit c7e739f

Add LRN efficient GPU implement. (#5894)
1 parent 1d1555e commit c7e739f
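For reference, the cross-channel LRN that this operator computes (following the AlexNet paper cited in its DOC string) normalizes each activation by a squared sum over a window of n neighboring channels, with the denominator term stored in MidOut. In the notation of the code below, where N(i) is the window of channels around channel i clipped to [0, C):

\text{MidOut}_{i} = k + \alpha \sum_{j \in \mathcal{N}(i)} x_{j}^{2}, \qquad \text{Out}_{i} = x_{i} \cdot \text{MidOut}_{i}^{-\beta}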

File tree

4 files changed (+289 / -93 lines)


paddle/operators/lrn_op.cc

Lines changed: 100 additions & 4 deletions
@@ -19,6 +19,103 @@ namespace operators {
 
 using framework::Tensor;
 
+template <typename T>
+struct LRNFunctor<platform::CPUPlace, T> {
+  void operator()(const framework::ExecutionContext& ctx,
+                  const framework::Tensor& input, framework::Tensor* out,
+                  framework::Tensor* mid, int N, int C, int H, int W, int n,
+                  T k, T alpha, T beta) {
+    auto x_v = framework::EigenVector<T>::Flatten(input);
+
+    const int start = -(n - 1) / 2;
+    const int end = start + n;
+
+    auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
+    e_mid = e_mid.constant(k);
+
+    auto e_x = framework::EigenTensor<T, 4>::From(input);
+    for (int m = 0; m < N; m++) {
+      for (int i = 0; i < C; i++) {
+        for (int c = start; c <= end; c++) {
+          int ch = i + c;
+          if (ch >= 0 && ch < C) {
+            auto s = e_mid.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
+                                 Eigen::array<int, 4>({{1, 1, H, W}}));
+
+            auto r = e_x.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
+                               Eigen::array<int, 4>({{1, 1, H, W}}));
+
+            s += alpha * r.square();
+          }
+        }
+      }
+    }
+
+    auto out_e = framework::EigenVector<T>::Flatten(*out);
+    out_e = x_v * e_mid.reshape(Eigen::DSizes<int, 1>(e_mid.size())).pow(-beta);
+  }
+};
+template struct LRNFunctor<platform::CPUPlace, float>;
+template struct LRNFunctor<platform::CPUPlace, double>;
+
+template <typename T>
+struct LRNGradFunctor<platform::CPUPlace, T> {
+  void operator()(const framework::ExecutionContext& ctx,
+                  const framework::Tensor& x, const framework::Tensor& out,
+                  const framework::Tensor& mid, framework::Tensor* x_g,
+                  const framework::Tensor& out_g, int N, int C, int H, int W,
+                  int n, T alpha, T beta) {
+    T ratio = -2 * alpha * beta;
+    auto x_g_e = framework::EigenVector<T>::Flatten(*x_g);
+    x_g_e = x_g_e.constant(0.0);
+
+    auto e_x = framework::EigenTensor<T, 4>::From(x);
+    auto e_x_g = framework::EigenTensor<T, 4>::From(*x_g);
+    auto e_out = framework::EigenTensor<T, 4>::From(out);
+    auto e_out_g = framework::EigenTensor<T, 4>::From(out_g);
+    auto e_mid = framework::EigenTensor<T, 4>::From(mid);
+
+    const int start = -(n - 1) / 2;
+    const int end = start + n;
+    for (int m = 0; m < N; m++) {
+      for (int i = 0; i < C; i++) {
+        auto i_x = e_x.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
+                             Eigen::array<int, 4>({{1, 1, H, W}}));
+
+        auto i_x_g = e_x_g.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
+                                 Eigen::array<int, 4>({{1, 1, H, W}}));
+
+        auto i_out_g = e_out_g.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
+                                     Eigen::array<int, 4>({{1, 1, H, W}}));
+
+        auto i_mid = e_mid.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
+                                 Eigen::array<int, 4>({{1, 1, H, W}}));
+
+        i_x_g = i_mid.pow(-beta) * i_out_g;
+        for (int c = start; c <= end; c++) {
+          int ch = i + c;
+          if (ch < 0 || ch >= C) {
+            continue;
+          }
+
+          auto c_out = e_out.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
+                                   Eigen::array<int, 4>({{1, 1, H, W}}));
+
+          auto c_mid = e_mid.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
+                                   Eigen::array<int, 4>({{1, 1, H, W}}));
+
+          auto c_out_g = e_out_g.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
+                                       Eigen::array<int, 4>({{1, 1, H, W}}));
+
+          i_x_g += ratio * c_out_g * c_out * i_x / c_mid;
+        }
+      }
+    }
+  }
+};
+template struct LRNGradFunctor<platform::CPUPlace, float>;
+template struct LRNGradFunctor<platform::CPUPlace, double>;
+
 class LRNOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
@@ -83,8 +180,8 @@ class LRNOpMaker : public framework::OpProtoAndCheckerMaker {
     AddComment(R"DOC(
 Local Response Normalization Operator.
 
-This operator comes from the paper
-"ImageNet Classification with Deep Convolutional Neural Networks".
+This operator comes from the paper:
+<<ImageNet Classification with Deep Convolutional Neural Networks>>.
 
 The original formula is:
 
@@ -119,8 +216,7 @@ class LRNOpGrad : public framework::OperatorWithKernel {
  protected:
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
-    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("MidOut")),
-                   "Input(MidOut@GRAD) should not be null");
+    PADDLE_ENFORCE(ctx->HasInput("MidOut"), "Input(MidOut) should not be null");
     PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                    "Input(Out@GRAD) should not be null");
 
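The backward pass above follows from differentiating Out_i = x_i * MidOut_i^{-beta}, with MidOut_j depending on x_i whenever channel i lies in the window N(j). For each channel i this gives

\frac{\partial L}{\partial x_{i}} = \frac{\partial L}{\partial \text{Out}_{i}}\,\text{MidOut}_{i}^{-\beta} \;-\; 2\alpha\beta\, x_{i} \sum_{j \in \mathcal{N}(i)} \frac{\partial L}{\partial \text{Out}_{j}}\,\frac{\text{Out}_{j}}{\text{MidOut}_{j}},

which is exactly what LRNGradFunctor accumulates: the first term is i_mid.pow(-beta) * i_out_g, and ratio = -2 * alpha * beta scales the summed c_out_g * c_out * i_x / c_mid contributions from the neighboring channels.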

paddle/operators/lrn_op.cu

Lines changed: 158 additions & 2 deletions
@@ -12,11 +12,167 @@
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#define EIGEN_USE_GPU
 #include "paddle/operators/lrn_op.h"
 
-namespace ops = paddle::operators;
+namespace paddle {
+namespace operators {
+
+template <typename T>
+__global__ void KeCMRNormFillScale(int img_size, const T* in, T* mid, int C,
+                                   int H, int W, int size, T k, T alpha) {
+  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  if (idx < img_size) {
+    const int w = idx % W;
+    const int h = (idx / W) % H;
+    const int n = idx / W / H;
+    const int offset = (n * C * H + h) * W + w;
+
+    in += offset;
+    mid += offset;
+    const int step = H * W;
+    const int pre_pad = (size - 1) / 2;
+    const int post_pad = size - pre_pad - 1;
+
+    T accum = 0;
+    int index = 0;
+    while (index < C + post_pad) {
+      if (index < C) {
+        T val = in[index * step];
+        accum += val * val;
+      }
+      if (index >= size) {
+        T val = in[(index - size) * step];
+        accum -= val * val;
+      }
+      if (index >= post_pad) {
+        mid[(index - post_pad) * step] = k + accum * alpha;
+      }
+      ++index;
+    }
+  }
+}
+
+template <typename T>
+__global__ void KeCMRNormOutput(int input_size, const T* in, const T* mid,
+                                T negative_beta, T* out) {
+  const int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index < input_size) {
+    out[index] = in[index] * pow(mid[index], negative_beta);
+  }
+}
+
+template <typename T>
+void CrossMapNormal(const framework::ExecutionContext& ctx, const T* inputs,
+                    T* outputs, T* mid, int N, int C, int H, int W, int n, T k,
+                    T alpha, T beta) {
+  int img_size = N * H * W;
+  const int block_size = 1024;
+  int grid_size = (img_size + block_size - 1) / block_size;
+
+  KeCMRNormFillScale<
+      T><<<grid_size, block_size, 0, ctx.cuda_device_context().stream()>>>(
+      img_size, inputs, mid, C, H, W, n, k, alpha);
+
+  int input_size = N * H * W * C;
+  grid_size = (input_size + block_size - 1) / block_size;
+  KeCMRNormOutput<
+      T><<<grid_size, block_size, 0, ctx.cuda_device_context().stream()>>>(
+      input_size, inputs, mid, -beta, outputs);
+}
+
+template <typename T>
+struct LRNFunctor<platform::GPUPlace, T> {
+  void operator()(const framework::ExecutionContext& ctx,
+                  const framework::Tensor& input, framework::Tensor* out,
+                  framework::Tensor* mid, int N, int C, int H, int W, int n,
+                  T k, T alpha, T beta) {
+    CrossMapNormal<T>(
+        ctx, input.data<T>(), out->mutable_data<T>(ctx.GetPlace()),
+        mid->mutable_data<T>(ctx.GetPlace()), N, C, H, W, n, k, alpha, beta);
+  }
+};
+
+template struct LRNFunctor<platform::GPUPlace, float>;
+template struct LRNFunctor<platform::GPUPlace, double>;
 
+template <typename T>
+__global__ void KeCMRNormDiff(int img_size, const T* x, const T* out,
+                              const T* mid, T* x_g, const T* out_g, int C,
+                              int H, int W, int size, T negative_beta,
+                              T ratio) {
+  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  if (idx < img_size) {
+    const int w = idx % W;
+    const int h = (idx / W) % H;
+    const int n = idx / W / H;
+    const int offset = (n * C * H + h) * W + w;
+    x += offset;
+    out += offset;
+    mid += offset;
+    out_g += offset;
+    x_g += offset;
+
+    const int step = H * W;
+    const int pre_pad = size - (size + 1) / 2;
+    const int post_pad = size - pre_pad - 1;
+
+    int index = 0;
+    T accum = 0;
+    // TODO(gongwb): optimize this with thread shared array.
+    while (index < C + post_pad) {
+      if (index < C) {
+        x_g[index * step] = 0.0;
+        accum += out_g[index * step] * out[index * step] / mid[index * step];
+      }
+      if (index >= size) {
+        accum -= out_g[(index - size) * step] * out[(index - size) * step] /
+                 mid[(index - size) * step];
+      }
+      if (index >= post_pad) {
+        x_g[(index - post_pad) * step] +=
+            out_g[(index - post_pad) * step] *
+                pow(mid[(index - post_pad) * step], negative_beta) -
+            ratio * x[(index - post_pad) * step] * accum;
+      }
+      ++index;
+    }
+  }
+}
+
+template <typename T>
+void CrossMapNormalGrad(const framework::ExecutionContext& ctx, const T* x,
+                        const T* out, const T* mid, T* x_g, const T* out_g,
+                        int N, int C, int H, int W, int n, T alpha, T beta) {
+  int img_size = N * H * W;
+
+  const int block_size = 1024;
+  int grid_size = (img_size + block_size - 1) / block_size;
+
+  KeCMRNormDiff<
+      T><<<grid_size, block_size, 0, ctx.cuda_device_context().stream()>>>(
+      img_size, x, out, mid, x_g, out_g, C, H, W, n, -beta,
+      2.0f * alpha * beta);
+}
+
+template <typename T>
+struct LRNGradFunctor<platform::GPUPlace, T> {
+  void operator()(const framework::ExecutionContext& ctx,
+                  const framework::Tensor& x, const framework::Tensor& out,
+                  const framework::Tensor& mid, framework::Tensor* x_g,
+                  const framework::Tensor& out_g, int N, int C, int H, int W,
+                  int n, T alpha, T beta) {
+    CrossMapNormalGrad<T>(ctx, x.data<T>(), out.data<T>(), mid.data<T>(),
+                          x_g->mutable_data<T>(ctx.GetPlace()), out_g.data<T>(),
+                          N, C, H, W, n, alpha, beta);
+  }
+};
+
+template struct LRNGradFunctor<platform::GPUPlace, float>;
+template struct LRNGradFunctor<platform::GPUPlace, double>;
+} // namespace operators
+} // namespace paddle
+
+namespace ops = paddle::operators;
 REGISTER_OP_GPU_KERNEL(lrn, ops::LRNKernel<paddle::platform::GPUPlace, float>);
 REGISTER_OP_GPU_KERNEL(lrn_grad,
                        ops::LRNGradKernel<paddle::platform::GPUPlace, float>);
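The efficiency of the GPU path comes from the running squared-sum in KeCMRNormFillScale: each thread owns one (batch, h, w) position and slides the channel window by adding the square of the entering channel and subtracting the square of the leaving one, so a full column of C channels costs O(C) work instead of O(C * size). A minimal host-side C++ sketch of that loop, assuming NCHW layout where step is the stride between consecutive channels of the same pixel (H * W); the function name FillScaleOneColumn is illustrative, not from the patch:

void FillScaleOneColumn(const float* in, float* mid, int C, int step, int size,
                        float k, float alpha) {
  // Window of `size` channels, split into pre_pad channels before the center
  // and post_pad after it, mirroring the kernel above.
  const int pre_pad = (size - 1) / 2;
  const int post_pad = size - pre_pad - 1;
  float accum = 0;
  for (int index = 0; index < C + post_pad; ++index) {
    if (index < C) {
      // Channel `index` enters the window: add its squared value.
      float val = in[index * step];
      accum += val * val;
    }
    if (index >= size) {
      // Channel `index - size` has left the window: remove its contribution.
      float val = in[(index - size) * step];
      accum -= val * val;
    }
    if (index >= post_pad) {
      // The window centered on channel `index - post_pad` is now complete.
      mid[(index - post_pad) * step] = k + accum * alpha;
    }
  }
}

KeCMRNormDiff applies the same sliding-window pattern to the backward pass, accumulating out_g * out / mid over the window before combining it with the pointwise mid^(-beta) term.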
