@@ -15,6 +15,7 @@ limitations under the License. */
15
15
#include " paddle/fluid/operators/math/jit_kernel.h"
16
16
#include < string>
17
17
#include " paddle/fluid/operators/math/jit_kernel_macro.h"
18
+ #include " paddle/fluid/operators/math/jit_kernel_refer.h"
18
19
#include " paddle/fluid/platform/enforce.h"
19
20
20
21
#ifdef PADDLE_WITH_XBYAK
@@ -31,49 +32,6 @@ namespace math {
31
32
namespace jitkernel {
32
33
namespace jit = platform::jit;
33
34
34
- template <typename T>
35
- void VMulRefer (const T* x, const T* y, T* z, int n) {
36
- for (int i = 0 ; i < n; ++i) {
37
- z[i] = x[i] * y[i];
38
- }
39
- }
40
-
41
- template <typename T>
42
- void VAddRefer (const T* x, const T* y, T* z, int n) {
43
- for (int i = 0 ; i < n; ++i) {
44
- z[i] = x[i] + y[i];
45
- }
46
- }
47
-
48
- template <typename T>
49
- void VAddReluRefer (const T* x, const T* y, T* z, int n) {
50
- for (int i = 0 ; i < n; ++i) {
51
- z[i] = x[i] + y[i];
52
- z[i] = z[i] > 0 ? z[i] : 0 ;
53
- }
54
- }
55
-
56
- template <typename T>
57
- void VScalRefer (const T* a, const T* x, T* y, int n) {
58
- for (int i = 0 ; i < n; ++i) {
59
- y[i] = a[0 ] * x[i];
60
- }
61
- }
62
-
63
- template <typename T>
64
- void VAddBiasRefer (const T* a, const T* x, T* y, int n) {
65
- for (int i = 0 ; i < n; ++i) {
66
- y[i] = a[0 ] + x[i];
67
- }
68
- }
69
-
70
- template <typename T>
71
- void VReluRefer (const T* x, T* y, int n) {
72
- for (int i = 0 ; i < n; ++i) {
73
- y[i] = x[i] > 0 ? x[i] : 0 ;
74
- }
75
- }
76
-
77
35
#ifdef PADDLE_WITH_MKLML
78
36
template <typename T>
79
37
void VMulMKL (const T* x, const T* y, T* z, int n);
@@ -109,7 +67,7 @@ void VScalMKL<float>(const float* a, const float* x, float* y, int n) {
109
67
if (x == y) {
110
68
platform::dynload::cblas_sscal (n, *a, y, 1 );
111
69
} else {
112
- VScalRefer <float >(a, x, y, n);
70
+ refer::VScal <float >(a, x, y, n);
113
71
}
114
72
}
115
73
@@ -118,7 +76,7 @@ void VScalMKL<double>(const double* a, const double* x, double* y, int n) {
118
76
if (x == y) {
119
77
platform::dynload::cblas_dscal (n, *a, y, 1 );
120
78
} else {
121
- VScalRefer <double >(a, x, y, n);
79
+ refer::VScal <double >(a, x, y, n);
122
80
}
123
81
}
124
82
@@ -147,7 +105,7 @@ class VMulKernelImpl : public VMulKernel<T> {
147
105
return ;
148
106
}
149
107
#endif
150
- this ->Compute = VMulRefer <T>;
108
+ this ->Compute = refer::VMul <T>;
151
109
}
152
110
153
111
#ifdef PADDLE_WITH_XBYAK
@@ -198,7 +156,7 @@ class VAddKernelImpl : public VAddKernel<T> {
198
156
return ;
199
157
}
200
158
#endif
201
- this ->Compute = VAddRefer <T>;
159
+ this ->Compute = refer::VAdd <T>;
202
160
}
203
161
#ifdef PADDLE_WITH_XBYAK
204
162
@@ -242,7 +200,7 @@ class VAddReluKernelImpl : public VAddReluKernel<T> {
242
200
return ;
243
201
}
244
202
#endif
245
- this ->Compute = VAddReluRefer <T>;
203
+ this ->Compute = refer::VAddRelu <T>;
246
204
}
247
205
#ifdef PADDLE_WITH_XBYAK
248
206
@@ -280,7 +238,7 @@ class VScalKernelImpl : public VScalKernel<T> {
280
238
return ;
281
239
}
282
240
#endif
283
- this ->Compute = VScalRefer <T>;
241
+ this ->Compute = refer::VScal <T>;
284
242
}
285
243
#ifdef PADDLE_WITH_XBYAK
286
244
@@ -324,7 +282,7 @@ class VAddBiasKernelImpl : public VAddBiasKernel<T> {
324
282
}
325
283
#endif
326
284
327
- this ->Compute = VAddBiasRefer <T>;
285
+ this ->Compute = refer::VAddBias <T>;
328
286
}
329
287
#ifdef PADDLE_WITH_XBYAK
330
288
@@ -358,7 +316,7 @@ class VReluKernelImpl : public VReluKernel<T> {
358
316
}
359
317
#endif
360
318
361
- this ->Compute = VReluRefer <T>;
319
+ this ->Compute = refer::VRelu <T>;
362
320
}
363
321
#ifdef PADDLE_WITH_XBYAK
364
322
@@ -374,16 +332,13 @@ bool VReluKernelImpl<float>::useJIT(int d) {
374
332
}
375
333
#endif
376
334
377
- template <typename T>
378
- inline void VIdentityRefer (const T* x, T* y, int n) {}
379
-
380
335
/* An empty JitKernel */
381
336
template <typename T>
382
337
class VIdentityKernelImpl : public VIdentityKernel <T> {
383
338
public:
384
339
JITKERNEL_DECLARE_STATIC_FUNC;
385
340
explicit VIdentityKernelImpl (int d) : VIdentityKernel<T>() {
386
- this ->Compute = VIdentityRefer <T>;
341
+ this ->Compute = refer::VIdentity <T>;
387
342
}
388
343
};
389
344
0 commit comments