Skip to content

Commit ce31deb

Browse files
committed
refine refer code and add lstm refer code
test=develop
1 parent c2cfb03 commit ce31deb

File tree

5 files changed

+220
-201
lines changed

5 files changed

+220
-201
lines changed

paddle/fluid/operators/math/jit_kernel_blas.cc

Lines changed: 10 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License. */
1515
#include "paddle/fluid/operators/math/jit_kernel.h"
1616
#include <string>
1717
#include "paddle/fluid/operators/math/jit_kernel_macro.h"
18+
#include "paddle/fluid/operators/math/jit_kernel_refer.h"
1819
#include "paddle/fluid/platform/enforce.h"
1920

2021
#ifdef PADDLE_WITH_XBYAK
@@ -31,49 +32,6 @@ namespace math {
3132
namespace jitkernel {
3233
namespace jit = platform::jit;
3334

34-
template <typename T>
35-
void VMulRefer(const T* x, const T* y, T* z, int n) {
36-
for (int i = 0; i < n; ++i) {
37-
z[i] = x[i] * y[i];
38-
}
39-
}
40-
41-
template <typename T>
42-
void VAddRefer(const T* x, const T* y, T* z, int n) {
43-
for (int i = 0; i < n; ++i) {
44-
z[i] = x[i] + y[i];
45-
}
46-
}
47-
48-
template <typename T>
49-
void VAddReluRefer(const T* x, const T* y, T* z, int n) {
50-
for (int i = 0; i < n; ++i) {
51-
z[i] = x[i] + y[i];
52-
z[i] = z[i] > 0 ? z[i] : 0;
53-
}
54-
}
55-
56-
template <typename T>
57-
void VScalRefer(const T* a, const T* x, T* y, int n) {
58-
for (int i = 0; i < n; ++i) {
59-
y[i] = a[0] * x[i];
60-
}
61-
}
62-
63-
template <typename T>
64-
void VAddBiasRefer(const T* a, const T* x, T* y, int n) {
65-
for (int i = 0; i < n; ++i) {
66-
y[i] = a[0] + x[i];
67-
}
68-
}
69-
70-
template <typename T>
71-
void VReluRefer(const T* x, T* y, int n) {
72-
for (int i = 0; i < n; ++i) {
73-
y[i] = x[i] > 0 ? x[i] : 0;
74-
}
75-
}
76-
7735
#ifdef PADDLE_WITH_MKLML
7836
template <typename T>
7937
void VMulMKL(const T* x, const T* y, T* z, int n);
@@ -109,7 +67,7 @@ void VScalMKL<float>(const float* a, const float* x, float* y, int n) {
10967
if (x == y) {
11068
platform::dynload::cblas_sscal(n, *a, y, 1);
11169
} else {
112-
VScalRefer<float>(a, x, y, n);
70+
refer::VScal<float>(a, x, y, n);
11371
}
11472
}
11573

@@ -118,7 +76,7 @@ void VScalMKL<double>(const double* a, const double* x, double* y, int n) {
11876
if (x == y) {
11977
platform::dynload::cblas_dscal(n, *a, y, 1);
12078
} else {
121-
VScalRefer<double>(a, x, y, n);
79+
refer::VScal<double>(a, x, y, n);
12280
}
12381
}
12482

@@ -147,7 +105,7 @@ class VMulKernelImpl : public VMulKernel<T> {
147105
return;
148106
}
149107
#endif
150-
this->Compute = VMulRefer<T>;
108+
this->Compute = refer::VMul<T>;
151109
}
152110

153111
#ifdef PADDLE_WITH_XBYAK
@@ -198,7 +156,7 @@ class VAddKernelImpl : public VAddKernel<T> {
198156
return;
199157
}
200158
#endif
201-
this->Compute = VAddRefer<T>;
159+
this->Compute = refer::VAdd<T>;
202160
}
203161
#ifdef PADDLE_WITH_XBYAK
204162

@@ -242,7 +200,7 @@ class VAddReluKernelImpl : public VAddReluKernel<T> {
242200
return;
243201
}
244202
#endif
245-
this->Compute = VAddReluRefer<T>;
203+
this->Compute = refer::VAddRelu<T>;
246204
}
247205
#ifdef PADDLE_WITH_XBYAK
248206

@@ -280,7 +238,7 @@ class VScalKernelImpl : public VScalKernel<T> {
280238
return;
281239
}
282240
#endif
283-
this->Compute = VScalRefer<T>;
241+
this->Compute = refer::VScal<T>;
284242
}
285243
#ifdef PADDLE_WITH_XBYAK
286244

@@ -324,7 +282,7 @@ class VAddBiasKernelImpl : public VAddBiasKernel<T> {
324282
}
325283
#endif
326284

327-
this->Compute = VAddBiasRefer<T>;
285+
this->Compute = refer::VAddBias<T>;
328286
}
329287
#ifdef PADDLE_WITH_XBYAK
330288

@@ -358,7 +316,7 @@ class VReluKernelImpl : public VReluKernel<T> {
358316
}
359317
#endif
360318

361-
this->Compute = VReluRefer<T>;
319+
this->Compute = refer::VRelu<T>;
362320
}
363321
#ifdef PADDLE_WITH_XBYAK
364322

@@ -374,16 +332,13 @@ bool VReluKernelImpl<float>::useJIT(int d) {
374332
}
375333
#endif
376334

377-
template <typename T>
378-
inline void VIdentityRefer(const T* x, T* y, int n) {}
379-
380335
/* An empty JitKernel */
381336
template <typename T>
382337
class VIdentityKernelImpl : public VIdentityKernel<T> {
383338
public:
384339
JITKERNEL_DECLARE_STATIC_FUNC;
385340
explicit VIdentityKernelImpl(int d) : VIdentityKernel<T>() {
386-
this->Compute = VIdentityRefer<T>;
341+
this->Compute = refer::VIdentity<T>;
387342
}
388343
};
389344

paddle/fluid/operators/math/jit_kernel_exp.cc

Lines changed: 4 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/operators/math/jit_kernel.h"
16-
#include <cmath> // for exp
1716
#include <string>
1817
#include "paddle/fluid/operators/math/jit_kernel_macro.h"
18+
#include "paddle/fluid/operators/math/jit_kernel_refer.h"
1919

2020
#ifdef PADDLE_WITH_XBYAK
2121
#include "paddle/fluid/operators/math/jit_code.h"
@@ -35,38 +35,6 @@ namespace math {
3535
namespace jitkernel {
3636
namespace jit = platform::jit;
3737

38-
// TODO(TJ): move refer codes to one file
39-
// Refer code only focus on correctness
40-
template <typename T>
41-
void VExpRefer(const T* x, T* y, int n) {
42-
for (int i = 0; i < n; ++i) {
43-
y[i] = std::exp(x[i]);
44-
}
45-
}
46-
47-
template <typename T>
48-
void VSigmoidRefer(const T* x, T* y, int n) {
49-
// y = 1 / (1 + e^-x)
50-
const T min = SIGMOID_THRESHOLD_MIN;
51-
const T max = SIGMOID_THRESHOLD_MAX;
52-
for (int i = 0; i < n; ++i) {
53-
T tmp = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
54-
y[i] = static_cast<T>(1) / (static_cast<T>(1) + std::exp(-tmp));
55-
}
56-
}
57-
58-
template <typename T>
59-
void VTanhRefer(const T* x, T* y, int n) {
60-
// y = 2 * sigmoid(2x) - 1
61-
for (int i = 0; i < n; ++i) {
62-
y[i] = static_cast<T>(2) * x[i];
63-
}
64-
VSigmoidRefer(y, y, n);
65-
for (int i = 0; i < n; ++i) {
66-
y[i] = static_cast<T>(2) * y[i] - static_cast<T>(1);
67-
}
68-
}
69-
7038
#ifdef PADDLE_WITH_MKLML
7139
// try to use MKL to speedup
7240
template <typename T>
@@ -129,7 +97,7 @@ class VExpKernelImpl : public VExpKernel<T> {
12997
return;
13098
}
13199
#endif
132-
this->Compute = VExpRefer<T>;
100+
this->Compute = refer::VExp<T>;
133101
}
134102

135103
#ifdef PADDLE_WITH_XBYAK
@@ -182,7 +150,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
182150
return;
183151
}
184152
#endif
185-
this->Compute = VSigmoidRefer<T>;
153+
this->Compute = refer::VSigmoid<T>;
186154
}
187155

188156
#ifdef PADDLE_WITH_XBYAK
@@ -234,7 +202,7 @@ class VTanhKernelImpl : public VTanhKernel<T> {
234202
return;
235203
}
236204
#endif
237-
this->Compute = VTanhRefer<T>;
205+
this->Compute = refer::VTanh<T>;
238206
}
239207

240208
#ifdef PADDLE_WITH_XBYAK

paddle/fluid/operators/math/jit_kernel_impl.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,13 @@ typedef struct {
3838
void* checked{nullptr};
3939
} lstm_t;
4040

41-
typedef struct {
41+
typedef struct lstm_attr_s {
4242
int d;
4343
std::string act_gate, act_cand, act_cell;
44+
lstm_attr_s() = default;
45+
lstm_attr_s(int _d, const std::string& _act_gate,
46+
const std::string& _act_cand, const std::string& _act_cell)
47+
: d(_d), act_gate(_act_gate), act_cand(_act_cand), act_cell(_act_cell) {}
4448
} lstm_attr_t;
4549

4650
} // namespace jitkernel

0 commit comments

Comments
 (0)