Skip to content

Commit e6d8aca

Browse files
committed
Refine JIT kernels: store the vector length in the kernel object (`num_`, set via a new `explicit` constructor taking `int d`) instead of passing `n` to every `Compute` call, and update the MKL/AVX specializations accordingly.
1 parent ea7dc9c commit e6d8aca

File tree

5 files changed

+214
-213
lines changed

5 files changed

+214
-213
lines changed

paddle/fluid/operators/math/jit_kernel.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,32 +64,32 @@ class KernelPool {
6464
template <typename T>
6565
class VMulKernel : public Kernel {
6666
public:
67-
virtual void Compute(const int n, const T *x, const T *y, T *z) const = 0;
67+
virtual void Compute(const T *x, const T *y, T *z) const = 0;
6868
};
6969

7070
template <typename T>
7171
class VAddKernel : public Kernel {
7272
public:
73-
virtual void Compute(const int n, const T *x, const T *y, T *z) const = 0;
73+
virtual void Compute(const T *x, const T *y, T *z) const = 0;
7474
};
7575

7676
template <typename T>
7777
class VScalKernel : public Kernel {
7878
public:
79-
virtual void Compute(const int n, const T a, const T *x, T *y) const = 0;
80-
virtual void Compute(const int n, const T a, T *x) const = 0;
79+
virtual void Compute(const T a, const T *x, T *y) const = 0;
80+
virtual void Compute(const T a, T *x) const = 0;
8181
};
8282

8383
template <typename T>
8484
class VAddBiasKernel : public Kernel {
8585
public:
86-
virtual void Compute(const int n, const T a, const T *x, T *y) const = 0;
86+
virtual void Compute(const T a, const T *x, T *y) const = 0;
8787
};
8888

8989
template <typename T>
9090
class VExpKernel : public Kernel {
9191
public:
92-
virtual void Compute(const int n, const T *x, T *y) const = 0;
92+
virtual void Compute(const T *x, T *y) const = 0;
9393
};
9494

9595
template <typename T>

paddle/fluid/operators/math/jit_kernel_blas.cc

Lines changed: 97 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -34,41 +34,42 @@ namespace jit = platform::jit;
3434
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
3535
class VMulKernelImpl : public VMulKernel<T> {
3636
public:
37-
void Compute(const int n, const T* x, const T* y, T* z) const override {
38-
for (int i = 0; i < n; ++i) {
37+
explicit VMulKernelImpl(int d) : VMulKernel<T>() { this->num_ = d; }
38+
void Compute(const T* x, const T* y, T* z) const override {
39+
for (int i = 0; i < this->num_; ++i) {
3940
z[i] = x[i] * y[i];
4041
}
4142
}
4243
};
4344

4445
#ifdef PADDLE_WITH_MKLML
45-
#define MKL_FLOAT(isa, block) \
46-
template <> \
47-
void VMulKernelImpl<float, isa, block>::Compute( \
48-
const int n, const float* x, const float* y, float* z) const { \
49-
platform::dynload::vsMul(n, x, y, z); \
46+
#define MKL_FLOAT(isa, block) \
47+
template <> \
48+
void VMulKernelImpl<float, isa, block>::Compute( \
49+
const float* x, const float* y, float* z) const { \
50+
platform::dynload::vsMul(this->num_, x, y, z); \
5051
}
5152

52-
#define MKL_DOUBLE(isa, block) \
53-
template <> \
54-
void VMulKernelImpl<double, isa, block>::Compute( \
55-
const int n, const double* x, const double* y, double* z) const { \
56-
platform::dynload::vdMul(n, x, y, z); \
53+
#define MKL_DOUBLE(isa, block) \
54+
template <> \
55+
void VMulKernelImpl<double, isa, block>::Compute( \
56+
const double* x, const double* y, double* z) const { \
57+
platform::dynload::vdMul(this->num_, x, y, z); \
5758
}
5859

5960
FOR_EACH_ISA(MKL_FLOAT, kGT16);
6061
FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
6162
#endif
6263

63-
#define INTRI8_FLOAT(isa) \
64-
template <> \
65-
void VMulKernelImpl<float, isa, kEQ8>::Compute( \
66-
const int n, const float* x, const float* y, float* z) const { \
67-
__m256 tmpx, tmpy; \
68-
tmpx = _mm256_loadu_ps(x); \
69-
tmpy = _mm256_loadu_ps(y); \
70-
tmpx = _mm256_mul_ps(tmpx, tmpy); \
71-
_mm256_storeu_ps(z, tmpx); \
64+
#define INTRI8_FLOAT(isa) \
65+
template <> \
66+
void VMulKernelImpl<float, isa, kEQ8>::Compute( \
67+
const float* x, const float* y, float* z) const { \
68+
__m256 tmpx, tmpy; \
69+
tmpx = _mm256_loadu_ps(x); \
70+
tmpy = _mm256_loadu_ps(y); \
71+
tmpx = _mm256_mul_ps(tmpx, tmpy); \
72+
_mm256_storeu_ps(z, tmpx); \
7273
}
7374

7475
// avx > for > mkl
@@ -90,41 +91,42 @@ INTRI8_FLOAT(jit::avx512f);
9091
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
9192
class VAddKernelImpl : public VAddKernel<T> {
9293
public:
93-
void Compute(const int n, const T* x, const T* y, T* z) const override {
94-
for (int i = 0; i < n; ++i) {
94+
explicit VAddKernelImpl(int d) : VAddKernel<T>() { this->num_ = d; }
95+
void Compute(const T* x, const T* y, T* z) const override {
96+
for (int i = 0; i < this->num_; ++i) {
9597
z[i] = x[i] + y[i];
9698
}
9799
}
98100
};
99101

100102
#ifdef PADDLE_WITH_MKLML
101-
#define MKL_FLOAT(isa, block) \
102-
template <> \
103-
void VAddKernelImpl<float, isa, block>::Compute( \
104-
const int n, const float* x, const float* y, float* z) const { \
105-
platform::dynload::vsAdd(n, x, y, z); \
103+
#define MKL_FLOAT(isa, block) \
104+
template <> \
105+
void VAddKernelImpl<float, isa, block>::Compute( \
106+
const float* x, const float* y, float* z) const { \
107+
platform::dynload::vsAdd(this->num_, x, y, z); \
106108
}
107109

108-
#define MKL_DOUBLE(isa, block) \
109-
template <> \
110-
void VAddKernelImpl<double, isa, block>::Compute( \
111-
const int n, const double* x, const double* y, double* z) const { \
112-
platform::dynload::vdAdd(n, x, y, z); \
110+
#define MKL_DOUBLE(isa, block) \
111+
template <> \
112+
void VAddKernelImpl<double, isa, block>::Compute( \
113+
const double* x, const double* y, double* z) const { \
114+
platform::dynload::vdAdd(this->num_, x, y, z); \
113115
}
114116

115117
FOR_EACH_ISA(MKL_FLOAT, kGT16);
116118
FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
117119
#endif
118120

119-
#define INTRI8_FLOAT(isa) \
120-
template <> \
121-
void VAddKernelImpl<float, isa, kEQ8>::Compute( \
122-
const int n, const float* x, const float* y, float* z) const { \
123-
__m256 tmpx, tmpy; \
124-
tmpx = _mm256_loadu_ps(x); \
125-
tmpy = _mm256_loadu_ps(y); \
126-
tmpx = _mm256_add_ps(tmpx, tmpy); \
127-
_mm256_storeu_ps(z, tmpx); \
121+
#define INTRI8_FLOAT(isa) \
122+
template <> \
123+
void VAddKernelImpl<float, isa, kEQ8>::Compute( \
124+
const float* x, const float* y, float* z) const { \
125+
__m256 tmpx, tmpy; \
126+
tmpx = _mm256_loadu_ps(x); \
127+
tmpy = _mm256_loadu_ps(y); \
128+
tmpx = _mm256_add_ps(tmpx, tmpy); \
129+
_mm256_storeu_ps(z, tmpx); \
128130
}
129131
#ifdef __AVX__
130132
INTRI8_FLOAT(jit::avx);
@@ -145,56 +147,57 @@ INTRI8_FLOAT(jit::avx512f);
145147
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
146148
class VScalKernelImpl : public VScalKernel<T> {
147149
public:
148-
void Compute(const int n, const T a, const T* x, T* y) const override {
149-
for (int i = 0; i < n; ++i) {
150+
explicit VScalKernelImpl(int d) : VScalKernel<T>() { this->num_ = d; }
151+
void Compute(const T a, const T* x, T* y) const override {
152+
for (int i = 0; i < this->num_; ++i) {
150153
y[i] = a * x[i];
151154
}
152155
}
153-
void Compute(const int n, const T a, T* x) const override {
154-
for (int i = 0; i < n; ++i) {
156+
void Compute(const T a, T* x) const override {
157+
for (int i = 0; i < this->num_; ++i) {
155158
x[i] = a * x[i];
156159
}
157160
}
158161
};
159162

160163
#ifdef PADDLE_WITH_MKLML
161-
#define MKL_FLOAT(isa, block) \
162-
template <> \
163-
void VScalKernelImpl<float, isa, block>::Compute(const int n, const float a, \
164-
float* x) const { \
165-
platform::dynload::cblas_sscal(n, a, x, 1); \
164+
#define MKL_FLOAT(isa, block) \
165+
template <> \
166+
void VScalKernelImpl<float, isa, block>::Compute(const float a, float* x) \
167+
const { \
168+
platform::dynload::cblas_sscal(this->num_, a, x, 1); \
166169
}
167170

168-
#define MKL_DOUBLE(isa, block) \
169-
template <> \
170-
void VScalKernelImpl<double, isa, block>::Compute( \
171-
const int n, const double a, double* x) const { \
172-
platform::dynload::cblas_dscal(n, a, x, 1); \
171+
#define MKL_DOUBLE(isa, block) \
172+
template <> \
173+
void VScalKernelImpl<double, isa, block>::Compute(const double a, double* x) \
174+
const { \
175+
platform::dynload::cblas_dscal(this->num_, a, x, 1); \
173176
}
174177

175178
FOR_EACH_ISA(MKL_FLOAT, kGT16);
176179
FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
177180
#endif
178181

179-
#define INTRI8_FLOAT(isa) \
180-
template <> \
181-
void VScalKernelImpl<float, isa, kEQ8>::Compute( \
182-
const int n, const float a, const float* x, float* y) const { \
183-
__m256 tmp; \
184-
__m256 scalar = _mm256_set1_ps(a); \
185-
tmp = _mm256_loadu_ps(x); \
186-
tmp = _mm256_mul_ps(tmp, scalar); \
187-
_mm256_storeu_ps(y, tmp); \
182+
#define INTRI8_FLOAT(isa) \
183+
template <> \
184+
void VScalKernelImpl<float, isa, kEQ8>::Compute( \
185+
const float a, const float* x, float* y) const { \
186+
__m256 tmp; \
187+
__m256 scalar = _mm256_set1_ps(a); \
188+
tmp = _mm256_loadu_ps(x); \
189+
tmp = _mm256_mul_ps(tmp, scalar); \
190+
_mm256_storeu_ps(y, tmp); \
188191
}
189-
#define INTRI8_INPLACE_FLOAT(isa) \
190-
template <> \
191-
void VScalKernelImpl<float, isa, kEQ8>::Compute(const int n, const float a, \
192-
float* x) const { \
193-
__m256 tmp; \
194-
__m256 scalar = _mm256_set1_ps(a); \
195-
tmp = _mm256_loadu_ps(x); \
196-
tmp = _mm256_mul_ps(tmp, scalar); \
197-
_mm256_storeu_ps(x, tmp); \
192+
#define INTRI8_INPLACE_FLOAT(isa) \
193+
template <> \
194+
void VScalKernelImpl<float, isa, kEQ8>::Compute(const float a, float* x) \
195+
const { \
196+
__m256 tmp; \
197+
__m256 scalar = _mm256_set1_ps(a); \
198+
tmp = _mm256_loadu_ps(x); \
199+
tmp = _mm256_mul_ps(tmp, scalar); \
200+
_mm256_storeu_ps(x, tmp); \
198201
}
199202

200203
#ifdef __AVX__
@@ -220,32 +223,33 @@ INTRI8_INPLACE_FLOAT(jit::avx512f);
220223
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
221224
class VAddBiasKernelImpl : public VAddBiasKernel<T> {
222225
public:
223-
void Compute(const int n, const T a, const T* x, T* y) const override {
224-
for (int i = 0; i < n; ++i) {
226+
explicit VAddBiasKernelImpl(int d) : VAddBiasKernel<T>() { this->num_ = d; }
227+
void Compute(const T a, const T* x, T* y) const override {
228+
for (int i = 0; i < this->num_; ++i) {
225229
y[i] = x[i] + a;
226230
}
227231
}
228232
};
229233

230-
#define INTRI8_FLOAT(isa) \
231-
template <> \
232-
void VAddBiasKernelImpl<float, isa, kEQ8>::Compute( \
233-
const int n, const float a, const float* x, float* y) const { \
234-
__m256 tmp = _mm256_loadu_ps(x); \
235-
tmp = _mm256_add_ps(tmp, _mm256_set1_ps(a)); \
236-
_mm256_storeu_ps(y, tmp); \
234+
#define INTRI8_FLOAT(isa) \
235+
template <> \
236+
void VAddBiasKernelImpl<float, isa, kEQ8>::Compute( \
237+
const float a, const float* x, float* y) const { \
238+
__m256 tmp = _mm256_loadu_ps(x); \
239+
tmp = _mm256_add_ps(tmp, _mm256_set1_ps(a)); \
240+
_mm256_storeu_ps(y, tmp); \
237241
}
238242

239-
#define INTRI16_FLOAT(isa) \
240-
template <> \
241-
void VAddBiasKernelImpl<float, isa, kEQ16>::Compute( \
242-
const int n, const float a, const float* x, float* y) const { \
243-
__m256 tmp0 = _mm256_loadu_ps(x); \
244-
__m256 tmp1 = _mm256_loadu_ps(x + 8); \
245-
tmp0 = _mm256_add_ps(tmp0, _mm256_set1_ps(a)); \
246-
tmp1 = _mm256_add_ps(tmp1, _mm256_set1_ps(a)); \
247-
_mm256_storeu_ps(y, tmp0); \
248-
_mm256_storeu_ps(y + 8, tmp1); \
243+
#define INTRI16_FLOAT(isa) \
244+
template <> \
245+
void VAddBiasKernelImpl<float, isa, kEQ16>::Compute( \
246+
const float a, const float* x, float* y) const { \
247+
__m256 tmp0 = _mm256_loadu_ps(x); \
248+
__m256 tmp1 = _mm256_loadu_ps(x + 8); \
249+
tmp0 = _mm256_add_ps(tmp0, _mm256_set1_ps(a)); \
250+
tmp1 = _mm256_add_ps(tmp1, _mm256_set1_ps(a)); \
251+
_mm256_storeu_ps(y, tmp0); \
252+
_mm256_storeu_ps(y + 8, tmp1); \
249253
}
250254

251255
#ifdef __AVX__

0 commit comments

Comments (0)