@@ -34,41 +34,42 @@ namespace jit = platform::jit;
 template <typename T, platform::jit::cpu_isa_t isa, jit_block>
 class VMulKernelImpl : public VMulKernel<T> {
  public:
-  void Compute(const int n, const T* x, const T* y, T* z) const override {
-    for (int i = 0; i < n; ++i) {
+  explicit VMulKernelImpl(int d) : VMulKernel<T>() { this->num_ = d; }
+  void Compute(const T* x, const T* y, T* z) const override {
+    for (int i = 0; i < this->num_; ++i) {
       z[i] = x[i] * y[i];
     }
   }
 };

 #ifdef PADDLE_WITH_MKLML
-#define MKL_FLOAT(isa, block)                                         \
-  template <>                                                         \
-  void VMulKernelImpl<float, isa, block>::Compute(                    \
-      const int n, const float* x, const float* y, float* z) const {  \
-    platform::dynload::vsMul(n, x, y, z);                             \
+#define MKL_FLOAT(isa, block)                           \
+  template <>                                           \
+  void VMulKernelImpl<float, isa, block>::Compute(      \
+      const float* x, const float* y, float* z) const { \
+    platform::dynload::vsMul(this->num_, x, y, z);      \
   }

-#define MKL_DOUBLE(isa, block)                                           \
-  template <>                                                            \
-  void VMulKernelImpl<double, isa, block>::Compute(                      \
-      const int n, const double* x, const double* y, double* z) const {  \
-    platform::dynload::vdMul(n, x, y, z);                                \
+#define MKL_DOUBLE(isa, block)                             \
+  template <>                                              \
+  void VMulKernelImpl<double, isa, block>::Compute(        \
+      const double* x, const double* y, double* z) const { \
+    platform::dynload::vdMul(this->num_, x, y, z);         \
   }

 FOR_EACH_ISA(MKL_FLOAT, kGT16);
 FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
 #endif

-#define INTRI8_FLOAT(isa)                                             \
-  template <>                                                         \
-  void VMulKernelImpl<float, isa, kEQ8>::Compute(                     \
-      const int n, const float* x, const float* y, float* z) const {  \
-    __m256 tmpx, tmpy;                                                \
-    tmpx = _mm256_loadu_ps(x);                                        \
-    tmpy = _mm256_loadu_ps(y);                                        \
-    tmpx = _mm256_mul_ps(tmpx, tmpy);                                 \
-    _mm256_storeu_ps(z, tmpx);                                        \
+#define INTRI8_FLOAT(isa)                               \
+  template <>                                           \
+  void VMulKernelImpl<float, isa, kEQ8>::Compute(       \
+      const float* x, const float* y, float* z) const { \
+    __m256 tmpx, tmpy;                                  \
+    tmpx = _mm256_loadu_ps(x);                          \
+    tmpy = _mm256_loadu_ps(y);                          \
+    tmpx = _mm256_mul_ps(tmpx, tmpy);                   \
+    _mm256_storeu_ps(z, tmpx);                          \
   }

 // avx > for > mkl
@@ -90,41 +91,42 @@ INTRI8_FLOAT(jit::avx512f);
 template <typename T, platform::jit::cpu_isa_t isa, jit_block>
 class VAddKernelImpl : public VAddKernel<T> {
  public:
-  void Compute(const int n, const T* x, const T* y, T* z) const override {
-    for (int i = 0; i < n; ++i) {
+  explicit VAddKernelImpl(int d) : VAddKernel<T>() { this->num_ = d; }
+  void Compute(const T* x, const T* y, T* z) const override {
+    for (int i = 0; i < this->num_; ++i) {
       z[i] = x[i] + y[i];
     }
   }
 };

 #ifdef PADDLE_WITH_MKLML
-#define MKL_FLOAT(isa, block)                                         \
-  template <>                                                         \
-  void VAddKernelImpl<float, isa, block>::Compute(                    \
-      const int n, const float* x, const float* y, float* z) const {  \
-    platform::dynload::vsAdd(n, x, y, z);                             \
+#define MKL_FLOAT(isa, block)                           \
+  template <>                                           \
+  void VAddKernelImpl<float, isa, block>::Compute(      \
+      const float* x, const float* y, float* z) const { \
+    platform::dynload::vsAdd(this->num_, x, y, z);      \
   }

-#define MKL_DOUBLE(isa, block)                                           \
-  template <>                                                            \
-  void VAddKernelImpl<double, isa, block>::Compute(                      \
-      const int n, const double* x, const double* y, double* z) const {  \
-    platform::dynload::vdAdd(n, x, y, z);                                \
+#define MKL_DOUBLE(isa, block)                             \
+  template <>                                              \
+  void VAddKernelImpl<double, isa, block>::Compute(        \
+      const double* x, const double* y, double* z) const { \
+    platform::dynload::vdAdd(this->num_, x, y, z);         \
   }

 FOR_EACH_ISA(MKL_FLOAT, kGT16);
 FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
 #endif

-#define INTRI8_FLOAT(isa)                                             \
-  template <>                                                         \
-  void VAddKernelImpl<float, isa, kEQ8>::Compute(                     \
-      const int n, const float* x, const float* y, float* z) const {  \
-    __m256 tmpx, tmpy;                                                \
-    tmpx = _mm256_loadu_ps(x);                                        \
-    tmpy = _mm256_loadu_ps(y);                                        \
-    tmpx = _mm256_add_ps(tmpx, tmpy);                                 \
-    _mm256_storeu_ps(z, tmpx);                                        \
+#define INTRI8_FLOAT(isa)                               \
+  template <>                                           \
+  void VAddKernelImpl<float, isa, kEQ8>::Compute(       \
+      const float* x, const float* y, float* z) const { \
+    __m256 tmpx, tmpy;                                  \
+    tmpx = _mm256_loadu_ps(x);                          \
+    tmpy = _mm256_loadu_ps(y);                          \
+    tmpx = _mm256_add_ps(tmpx, tmpy);                   \
+    _mm256_storeu_ps(z, tmpx);                          \
   }
 #ifdef __AVX__
 INTRI8_FLOAT(jit::avx);
@@ -145,56 +147,57 @@ INTRI8_FLOAT(jit::avx512f);
 template <typename T, platform::jit::cpu_isa_t isa, jit_block>
 class VScalKernelImpl : public VScalKernel<T> {
  public:
-  void Compute(const int n, const T a, const T* x, T* y) const override {
-    for (int i = 0; i < n; ++i) {
+  explicit VScalKernelImpl(int d) : VScalKernel<T>() { this->num_ = d; }
+  void Compute(const T a, const T* x, T* y) const override {
+    for (int i = 0; i < this->num_; ++i) {
       y[i] = a * x[i];
     }
   }
-  void Compute(const int n, const T a, T* x) const override {
-    for (int i = 0; i < n; ++i) {
+  void Compute(const T a, T* x) const override {
+    for (int i = 0; i < this->num_; ++i) {
       x[i] = a * x[i];
     }
   }
 };

 #ifdef PADDLE_WITH_MKLML
-#define MKL_FLOAT(isa, block)                                                  \
-  template <>                                                                  \
-  void VScalKernelImpl<float, isa, block>::Compute(const int n, const float a, \
-                                                   float* x) const {           \
-    platform::dynload::cblas_sscal(n, a, x, 1);                                \
+#define MKL_FLOAT(isa, block)                                                \
+  template <>                                                                \
+  void VScalKernelImpl<float, isa, block>::Compute(const float a, float* x)  \
+      const {                                                                \
+    platform::dynload::cblas_sscal(this->num_, a, x, 1);                     \
   }

-#define MKL_DOUBLE(isa, block)                         \
-  template <>                                          \
-  void VScalKernelImpl<double, isa, block>::Compute(   \
-      const int n, const double a, double* x) const {  \
-    platform::dynload::cblas_dscal(n, a, x, 1);        \
+#define MKL_DOUBLE(isa, block)                                                 \
+  template <>                                                                  \
+  void VScalKernelImpl<double, isa, block>::Compute(const double a, double* x) \
+      const {                                                                  \
+    platform::dynload::cblas_dscal(this->num_, a, x, 1);                       \
   }

 FOR_EACH_ISA(MKL_FLOAT, kGT16);
 FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
 #endif

-#define INTRI8_FLOAT(isa)                                            \
-  template <>                                                        \
-  void VScalKernelImpl<float, isa, kEQ8>::Compute(                   \
-      const int n, const float a, const float* x, float* y) const {  \
-    __m256 tmp;                                                      \
-    __m256 scalar = _mm256_set1_ps(a);                               \
-    tmp = _mm256_loadu_ps(x);                                        \
-    tmp = _mm256_mul_ps(tmp, scalar);                                \
-    _mm256_storeu_ps(y, tmp);                                        \
+#define INTRI8_FLOAT(isa)                               \
+  template <>                                           \
+  void VScalKernelImpl<float, isa, kEQ8>::Compute(      \
+      const float a, const float* x, float* y) const {  \
+    __m256 tmp;                                         \
+    __m256 scalar = _mm256_set1_ps(a);                  \
+    tmp = _mm256_loadu_ps(x);                           \
+    tmp = _mm256_mul_ps(tmp, scalar);                   \
+    _mm256_storeu_ps(y, tmp);                           \
   }
-#define INTRI8_INPLACE_FLOAT(isa)                                             \
-  template <>                                                                 \
-  void VScalKernelImpl<float, isa, kEQ8>::Compute(const int n, const float a, \
-                                                  float* x) const {           \
-    __m256 tmp;                                                               \
-    __m256 scalar = _mm256_set1_ps(a);                                        \
-    tmp = _mm256_loadu_ps(x);                                                 \
-    tmp = _mm256_mul_ps(tmp, scalar);                                         \
-    _mm256_storeu_ps(x, tmp);                                                 \
+#define INTRI8_INPLACE_FLOAT(isa)                                           \
+  template <>                                                               \
+  void VScalKernelImpl<float, isa, kEQ8>::Compute(const float a, float* x)  \
+      const {                                                               \
+    __m256 tmp;                                                             \
+    __m256 scalar = _mm256_set1_ps(a);                                      \
+    tmp = _mm256_loadu_ps(x);                                               \
+    tmp = _mm256_mul_ps(tmp, scalar);                                       \
+    _mm256_storeu_ps(x, tmp);                                               \
   }

 #ifdef __AVX__
@@ -220,32 +223,33 @@ INTRI8_INPLACE_FLOAT(jit::avx512f);
 template <typename T, platform::jit::cpu_isa_t isa, jit_block>
 class VAddBiasKernelImpl : public VAddBiasKernel<T> {
  public:
-  void Compute(const int n, const T a, const T* x, T* y) const override {
-    for (int i = 0; i < n; ++i) {
+  explicit VAddBiasKernelImpl(int d) : VAddBiasKernel<T>() { this->num_ = d; }
+  void Compute(const T a, const T* x, T* y) const override {
+    for (int i = 0; i < this->num_; ++i) {
       y[i] = x[i] + a;
     }
   }
 };

-#define INTRI8_FLOAT(isa)                                            \
-  template <>                                                        \
-  void VAddBiasKernelImpl<float, isa, kEQ8>::Compute(                \
-      const int n, const float a, const float* x, float* y) const {  \
-    __m256 tmp = _mm256_loadu_ps(x);                                 \
-    tmp = _mm256_add_ps(tmp, _mm256_set1_ps(a));                     \
-    _mm256_storeu_ps(y, tmp);                                        \
+#define INTRI8_FLOAT(isa)                               \
+  template <>                                           \
+  void VAddBiasKernelImpl<float, isa, kEQ8>::Compute(   \
+      const float a, const float* x, float* y) const {  \
+    __m256 tmp = _mm256_loadu_ps(x);                    \
+    tmp = _mm256_add_ps(tmp, _mm256_set1_ps(a));        \
+    _mm256_storeu_ps(y, tmp);                           \
   }

-#define INTRI16_FLOAT(isa)                                           \
-  template <>                                                        \
-  void VAddBiasKernelImpl<float, isa, kEQ16>::Compute(               \
-      const int n, const float a, const float* x, float* y) const {  \
-    __m256 tmp0 = _mm256_loadu_ps(x);                                \
-    __m256 tmp1 = _mm256_loadu_ps(x + 8);                            \
-    tmp0 = _mm256_add_ps(tmp0, _mm256_set1_ps(a));                   \
-    tmp1 = _mm256_add_ps(tmp1, _mm256_set1_ps(a));                   \
-    _mm256_storeu_ps(y, tmp0);                                       \
-    _mm256_storeu_ps(y + 8, tmp1);                                   \
+#define INTRI16_FLOAT(isa)                              \
+  template <>                                           \
+  void VAddBiasKernelImpl<float, isa, kEQ16>::Compute(  \
+      const float a, const float* x, float* y) const {  \
+    __m256 tmp0 = _mm256_loadu_ps(x);                   \
+    __m256 tmp1 = _mm256_loadu_ps(x + 8);               \
+    tmp0 = _mm256_add_ps(tmp0, _mm256_set1_ps(a));      \
+    tmp1 = _mm256_add_ps(tmp1, _mm256_set1_ps(a));      \
+    _mm256_storeu_ps(y, tmp0);                          \
+    _mm256_storeu_ps(y + 8, tmp1);                      \
   }

 #ifdef __AVX__
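Editor's note on the pattern above: every hunk makes the same API change. The element count moves out of each Compute signature (const int n, ...) into a num_ member that the kernel sets once at construction (this->num_ = d), so callers bind a kernel to a size up front and then pass only data pointers per call. Below is a minimal, self-contained sketch of that calling convention; the VMulKernel layout with a bare num_ field and the VMulRefKernel name are simplified assumptions for illustration, not Paddle's actual class hierarchy or kernel registry.

#include <cstdio>
#include <vector>

// Hypothetical, simplified stand-in for the kernel hierarchy in this diff:
// the base class owns the vector length num_, set once by the constructor
// of the concrete kernel, so Compute() needs no per-call length argument.
template <typename T>
class VMulKernel {
 public:
  virtual void Compute(const T* x, const T* y, T* z) const = 0;
  virtual ~VMulKernel() = default;

 protected:
  int num_{0};  // element count, bound at construction time
};

template <typename T>
class VMulRefKernel : public VMulKernel<T> {
 public:
  explicit VMulRefKernel(int d) { this->num_ = d; }  // same idiom as the diff
  void Compute(const T* x, const T* y, T* z) const override {
    for (int i = 0; i < this->num_; ++i) {
      z[i] = x[i] * y[i];
    }
  }
};

int main() {
  std::vector<float> x(8, 2.0f), y(8, 3.0f), z(8, 0.0f);
  VMulRefKernel<float> kernel(static_cast<int>(x.size()));  // size fixed here
  kernel.Compute(x.data(), y.data(), z.data());             // pointers only
  std::printf("z[0] = %.1f\n", z[0]);                       // prints 6.0
  return 0;
}

The payoff shows in the fixed-size specializations: once the length is carried by the block tag plus a construction-time member, a kEQ8 AVX variant can drop the length argument entirely and issue exactly one 8-float load, multiply, and store.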