Commit f4c869d

yihuaxu authored and tensor-tang committed
Optimize the layer_norm operator with AVX intrinsic function (#14417)
* Optimize layer_norm operator with AVX intrinsic functions
* Revert the wrong modifications
* Implement the jit kernel for layer_norm operator
* Add math header file to fix the compile issue (test=develop)
* Add math header file to fix the compile issue (test=develop)
* Fixed the intrinsic header file issue (test=develop)
* Fix the conflicts (test=develop)
* Revert for CUDA compiler (test=develop)
* Fixed the CUDA dependency (test=develop)
* Fix the macro issues (test=develop)
1 parent 816b464 commit f4c869d

4 files changed: +269 -1 lines changed
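For reference, the computation the new kernel implements is standard layer normalization over each row of the flattened [left x right] input. A minimal statement of the math (notation mine, inferred from the scalar reference implementation added in this commit):

\mu_i = \frac{1}{N}\sum_{j=1}^{N} x_{i,j}, \quad
\sigma_i^2 = \frac{1}{N}\sum_{j=1}^{N} (x_{i,j} - \mu_i)^2, \quad
\mathrm{out}_{i,j} = \frac{x_{i,j} - \mu_i}{\sqrt{\sigma_i^2 + \epsilon}} \cdot \mathrm{scale}_j + \mathrm{bias}_j

where N = right (the normalized feature width), i runs over the left (row) dimension, and scale/bias are optional per-feature parameters.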

paddle/fluid/operators/layer_norm_op.h

Lines changed: 19 additions & 0 deletions

@@ -17,6 +17,10 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
 #include "paddle/fluid/operators/math/blas.h"
+#if !defined(PADDLE_WITH_CUDA) && !defined(_WIN32) && !defined(__APPLE__) && \
+    !defined(__OSX__)
+#include "paddle/fluid/operators/math/jit_kernel.h"
+#endif
 #include "paddle/fluid/operators/math/math_function.h"

 namespace paddle {
@@ -191,6 +195,8 @@ class LayerNormKernel : public framework::OpKernel<T> {
     out.ShareDataWith(*y);
     out.Resize(matrix_shape);

+#if defined(PADDLE_WITH_CUDA) || defined(_WIN32) || defined(__APPLE__) || \
+    defined(__OSX__)
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
     RowwiseMean2D<DeviceContext, T> row_mean(left, right, ctx.device_context());

@@ -217,6 +223,19 @@ class LayerNormKernel : public framework::OpKernel<T> {
       ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(
           ctx, &out, bias, /*axis*/ 1, AddFunctor<T>(), &out);
     }
+#else
+    PADDLE_ENFORCE_EQ(mean->numel(), left);
+    PADDLE_ENFORCE_EQ(var->numel(), left);
+    PADDLE_ENFORCE_EQ(scale->numel(), right);
+    PADDLE_ENFORCE_EQ(bias->numel(), right);
+
+    const auto& ker = math::jitkernel::KernelPool::Instance()
+                          .template Get<math::jitkernel::LayerNormKernel<T>>(
+                              static_cast<int>(right));
+    ker->Compute(x.data<T>(), out.data<T>(), mean->data<T>(), var->data<T>(),
+                 scale->data<T>(), bias->data<T>(), static_cast<int>(left),
+                 static_cast<const float>(epsilon));
+#endif
   }
 };

paddle/fluid/operators/math/CMakeLists.txt

Lines changed: 1 addition & 1 deletion

@@ -77,7 +77,7 @@ endif()
 cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split)
 cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
 if (NOT WIN32)
-  set(JIT_KERNEL_SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc jit_kernel_crf_decode.cc)
+  set(JIT_KERNEL_SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc jit_kernel_crf_decode.cc jit_kernel_layer_norm.cc)
   set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce)
   if(WITH_XBYAK)
     list(APPEND JIT_KERNEL_SRCS jit_gen.cc jit_code.cc)

paddle/fluid/operators/math/jit_kernel.h

Lines changed: 8 additions & 0 deletions

@@ -145,6 +145,14 @@ class CRFDecodeKernel : public Kernel {
                        int *track) const = 0;
 };

+template <typename T>
+class LayerNormKernel : public Kernel {
+ public:
+  virtual void Compute(T *x, T *out, T *mean, T *var, const T *scale,
+                       const T *bias, int height,
+                       const float epsilon) const = 0;
+};
+
 }  // namespace jitkernel
 }  // namespace math
 }  // namespace operators
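
A minimal usage sketch of this new interface, assuming a CPU build where the jit kernels are compiled in; buffer sizes and values are illustrative, while KernelPool::Instance().Get<...>() and the Compute() argument order are taken from this commit:

#include <vector>
#include "paddle/fluid/operators/math/jit_kernel.h"

namespace jk = paddle::operators::math::jitkernel;

void LayerNormSketch() {
  const int height = 4;  // "left": number of rows to normalize
  const int width = 16;  // "right": normalized feature width
  std::vector<float> x(height * width, 1.0f), out(height * width);
  std::vector<float> mean(height), var(height);
  std::vector<float> scale(width, 1.0f), bias(width, 0.0f);

  // The pool hands back a kernel specialized for this feature width.
  const auto& ker =
      jk::KernelPool::Instance().Get<jk::LayerNormKernel<float>>(width);
  ker->Compute(x.data(), out.data(), mean.data(), var.data(), scale.data(),
               bias.data(), height, 1e-5f);
}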
paddle/fluid/operators/math/jit_kernel_layer_norm.cc (new file)

Lines changed: 241 additions & 0 deletions

@@ -0,0 +1,241 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+#include "paddle/fluid/operators/math/jit_kernel.h"
+#include <math.h>
+#include <limits>
+#include <string>
+#include "paddle/fluid/operators/math/jit_kernel_macro.h"
+#ifdef __AVX__
+#include <immintrin.h>
+#endif
+
+namespace paddle {
+namespace operators {
+namespace math {
+namespace jitkernel {
+
+namespace jit = platform::jit;
+
+/* Layer Norm JitKernel */
+template <typename T, platform::jit::cpu_isa_t isa, jit_block>
+class LayerNormKernelImpl : public LayerNormKernel<T> {
+ public:
+  explicit LayerNormKernelImpl(int right) : LayerNormKernel<T>() {
+    this->num_ = right;
+  }
+
+  void Compute(T* x, T* out, T* mean, T* var, const T* scale, const T* bias,
+               int height, const float epsilon) const override {
+    // get mean
+    for (int i = 0; i < height; i++) {
+      T sum = 0.0;
+      int offset = i * this->num_;
+      for (int j = 0; j < this->num_; j++) {
+        sum += x[offset + j];
+      }
+      mean[i] = sum / this->num_;
+    }
+
+    // get variance
+    for (int i = 0; i < height; i++) {
+      T sum = 0.0;
+      int offset = i * this->num_;
+      for (int j = 0; j < this->num_; j++) {
+        sum += (x[offset + j] - mean[i]) * (x[offset + j] - mean[i]);
+      }
+      var[i] = sum / this->num_;
+    }
+
+    for (int i = 0; i < height; i++) {
+      int offset = i * this->num_;
+      T sqrt_var = sqrt(var[i] + (T)epsilon);
+      for (int j = 0; j < this->num_; j++) {
+        out[offset + j] = (x[offset + j] - mean[i]) / sqrt_var;
+      }
+    }
+    if (scale) {
+      for (int i = 0; i < height; i++) {
+        int offset = i * this->num_;
+        for (int j = 0; j < this->num_; j++) {
+          out[offset + j] *= scale[j];
+        }
+      }
+    }
+
+    if (bias) {
+      for (int i = 0; i < height; i++) {
+        int offset = i * this->num_;
+        for (int j = 0; j < this->num_; j++) {
+          out[offset + j] += bias[j];
+        }
+      }
+    }
+  }
+};
+
+#define INTRIAVX_FLOAT(isa, block)                                            \
+  template <>                                                                 \
+  LayerNormKernelImpl<float, isa, block>::LayerNormKernelImpl(int right)      \
+      : LayerNormKernel<float>() {                                            \
+    this->num_ = right;                                                       \
+    this->rest_ = this->num_ % YMM_FLOAT_BLOCK;                               \
+    this->end_ = this->num_ - this->rest_;                                    \
+  }                                                                           \
+  template <>                                                                 \
+  void LayerNormKernelImpl<float, jit::avx, block>::Compute(                  \
+      float* x, float* out, float* mean, float* var, const float* scale,      \
+      const float* bias, int height, const float epsilon) const {             \
+    __m256 sum;                                                               \
+    __m256 mean_vec, var_vec;                                                 \
+    __m128 hi, lo;                                                            \
+    __m256 tmp;                                                               \
+    size_t offset;                                                            \
+    size_t j;                                                                 \
+    __m256 reverse_num_vec =                                                  \
+        _mm256_div_ps(_mm256_set1_ps(1.0), _mm256_set1_ps(this->num_));       \
+    __m256 epsilon_vec = _mm256_set1_ps(epsilon);                             \
+    int rest_mask =                                                           \
+        ((-1) & (~((~0U) >> (sizeof(int) * 8 - (YMM_FLOAT_BLOCK - rest_))))) & \
+        0x0ff;                                                                \
+    __m256i mask_vec = _mm256_set_epi32(                                      \
+        rest_mask & 0x80 ? 0xffffffff : 0, rest_mask & 0x40 ? 0xffffffff : 0, \
+        rest_mask & 0x20 ? 0xffffffff : 0, rest_mask & 0x10 ? 0xffffffff : 0, \
+        rest_mask & 0x8 ? 0xffffffff : 0, rest_mask & 0x4 ? 0xffffffff : 0,   \
+        rest_mask & 0x2 ? 0xffffffff : 0, rest_mask & 0x1 ? 0xffffffff : 0);  \
+                                                                              \
+    for (int i = 0; i < height; ++i) {                                        \
+      offset = i * this->num_;                                                \
+                                                                              \
+      /* get mean */                                                          \
+      sum = _mm256_setzero_ps();                                              \
+      for (j = offset; j < end_ + offset; j += block) {                       \
+        sum = _mm256_add_ps(sum, _mm256_loadu_ps((const float*)x + j));       \
+      }                                                                       \
+      if (rest_ != 0) {                                                       \
+        j = offset + this->num_ - block;                                      \
+        tmp = _mm256_loadu_ps((const float*)x + j);                           \
+        tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp, (__m256)mask_vec);   \
+        sum = _mm256_add_ps(sum, tmp);                                        \
+      }                                                                       \
+      hi = _mm256_extractf128_ps(sum, 1);                                     \
+      lo = _mm256_extractf128_ps(sum, 0);                                     \
+      sum = _mm256_add_ps(                                                    \
+          sum, _mm256_insertf128_ps(                                          \
+                   _mm256_insertf128_ps(_mm256_setzero_ps(), hi, 0), lo, 1)); \
+      sum = _mm256_hadd_ps(sum, sum);                                         \
+      sum = _mm256_hadd_ps(sum, sum);                                         \
+      mean_vec = _mm256_mul_ps(sum, reverse_num_vec);                         \
+      mean[i] = *reinterpret_cast<float*>(&mean_vec);                         \
+                                                                              \
+      /* get variance */                                                      \
+      sum = _mm256_setzero_ps();                                              \
+      for (j = offset; j < end_ + offset; j += block) {                       \
+        tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec);  \
+        tmp = _mm256_mul_ps(tmp, tmp);                                        \
+        sum = _mm256_add_ps(sum, tmp);                                        \
+      }                                                                       \
+      if (rest_ != 0) {                                                       \
+        j = offset + this->num_ - block;                                      \
+        tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec);  \
+        tmp = _mm256_mul_ps(tmp, tmp);                                        \
+        tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp, (__m256)mask_vec);   \
+        sum = _mm256_add_ps(sum, tmp);                                        \
+      }                                                                       \
+      hi = _mm256_extractf128_ps(sum, 1);                                     \
+      lo = _mm256_extractf128_ps(sum, 0);                                     \
+      sum = _mm256_add_ps(                                                    \
+          sum, _mm256_insertf128_ps(                                          \
+                   _mm256_insertf128_ps(_mm256_setzero_ps(), hi, 0), lo, 1)); \
+      sum = _mm256_hadd_ps(sum, sum);                                         \
+      sum = _mm256_hadd_ps(sum, sum);                                         \
+      var_vec = _mm256_mul_ps(sum, reverse_num_vec);                          \
+      var[i] = *reinterpret_cast<float*>(&var_vec);                           \
+                                                                              \
+      /* get x_norm and calculate output*/                                    \
+      for (j = offset; j < end_ + offset; j += block) {                       \
+        tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec);  \
+        tmp = _mm256_div_ps(                                                  \
+            tmp, _mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec)));        \
+        _mm256_storeu_ps(reinterpret_cast<float*>(out) + j, tmp);             \
+      }                                                                       \
+      if (rest_ != 0) {                                                       \
+        j = offset + num_ - block;                                            \
+        tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec);  \
+        tmp = _mm256_div_ps(                                                  \
+            tmp, _mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec)));        \
+        _mm256_storeu_ps(reinterpret_cast<float*>(out) + j, tmp);             \
+      }                                                                       \
+                                                                              \
+      if (scale) {                                                            \
+        if (rest_ != 0) {                                                     \
+          j = offset + this->num_ - block;                                    \
+          tmp = _mm256_loadu_ps((const float*)out + j);                       \
+        }                                                                     \
+        for (j = offset; j < end_ + offset; j += block) {                     \
+          _mm256_storeu_ps(                                                   \
+              reinterpret_cast<float*>(out) + j,                              \
+              _mm256_mul_ps(                                                  \
+                  _mm256_loadu_ps((const float*)out + j),                     \
+                  _mm256_loadu_ps((const float*)scale + j - offset)));        \
+        }                                                                     \
+        if (rest_ != 0) {                                                     \
+          j = offset + this->num_ - block;                                    \
+          _mm256_storeu_ps(                                                   \
+              reinterpret_cast<float*>(out) + j,                              \
+              _mm256_mul_ps(                                                  \
+                  tmp, _mm256_loadu_ps((const float*)scale + j - offset)));   \
+        }                                                                     \
+      }                                                                       \
+                                                                              \
+      if (bias) {                                                             \
+        if (rest_ != 0) {                                                     \
+          j = offset + this->num_ - block;                                    \
+          tmp = _mm256_loadu_ps((const float*)out + j);                       \
+        }                                                                     \
+        for (j = offset; j < end_ + offset; j += block) {                     \
+          _mm256_storeu_ps(                                                   \
+              reinterpret_cast<float*>(out) + j,                              \
+              _mm256_add_ps(                                                  \
+                  _mm256_loadu_ps((const float*)out + j),                     \
+                  _mm256_loadu_ps((const float*)bias + j - offset)));         \
+        }                                                                     \
+        if (rest_ != 0) {                                                     \
+          j = offset + this->num_ - block;                                    \
+          _mm256_storeu_ps(                                                   \
+              reinterpret_cast<float*>(out) + j,                              \
+              _mm256_add_ps(                                                  \
+                  tmp, _mm256_loadu_ps((const float*)bias + j - offset)));    \
+        }                                                                     \
+      }                                                                       \
+    }                                                                         \
+  }
+
+#ifdef __AVX__
+INTRIAVX_FLOAT(jit::avx, kEQ8);
+INTRIAVX_FLOAT(jit::avx, kGT8LT16);
+INTRIAVX_FLOAT(jit::avx, kEQ16);
+INTRIAVX_FLOAT(jit::avx, kGT16);
+#endif
+#ifdef __AVX2__
+INTRIAVX_FLOAT(jit::avx2, kEQ8);
+INTRIAVX_FLOAT(jit::avx2, kGT8LT16);
+INTRIAVX_FLOAT(jit::avx2, kEQ16);
+INTRIAVX_FLOAT(jit::avx2, kGT16);
+#endif
+
+#undef INTRIAVX_FLOAT
+
+REGISTER_JITKERNEL_DEPRECATED(layer_norm, LayerNormKernel);
+
+}  // namespace jitkernel
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
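
The AVX path above accumulates eight partial sums per row in a single __m256 register and then reduces it. A standalone sketch of the same horizontal-reduction idiom (simplified: the macro keeps the result broadcast in a 256-bit register instead of extracting a scalar):

#ifdef __AVX__
#include <immintrin.h>

// Reduce the eight float lanes of a __m256 to one scalar sum.
static inline float hsum256(__m256 v) {
  __m128 hi = _mm256_extractf128_ps(v, 1);  // upper four lanes
  __m128 lo = _mm256_castps256_ps128(v);    // lower four lanes
  __m128 s = _mm_add_ps(hi, lo);            // four pairwise sums
  s = _mm_hadd_ps(s, s);                    // two sums
  s = _mm_hadd_ps(s, s);                    // final sum in lane 0
  return _mm_cvtss_f32(s);
}
#endif

Widths that are not a multiple of YMM_FLOAT_BLOCK (eight floats) are handled by re-loading the last, overlapping block and zeroing the lanes that were already processed, using the mask_vec built from rest_ together with _mm256_blendv_ps.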
