Skip to content

Commit f2adaf1

Browse files
committed
add VRelu and LSTM kernels
test=develop
1 parent e6d8aca commit f2adaf1

File tree

6 files changed

+269
-75
lines changed

6 files changed

+269
-75
lines changed

paddle/fluid/operators/math/jit_kernel.cc

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -35,23 +35,6 @@ std::shared_ptr<const Kernel> KernelPool::Get(const std::string& key) const {
3535
return kers_.at(key);
3636
}
3737

38-
template <>
39-
std::shared_ptr<const LSTMKernel<float>>
40-
KernelPool::Get<LSTMKernel<float>, int, const std::string&, const std::string&,
41-
const std::string&>(int d, const std::string& act_gate,
42-
const std::string& act_cand,
43-
const std::string& act_cell) {
44-
std::string key =
45-
"lstmf" + std::to_string(d) + act_gate + act_cand + act_cell;
46-
if (kers_.find(key) == kers_.end()) {
47-
auto p =
48-
std::make_shared<LSTMKernel<float>>(d, act_gate, act_cand, act_cell);
49-
kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)});
50-
return p;
51-
}
52-
return std::dynamic_pointer_cast<const LSTMKernel<float>>(kers_.at(key));
53-
}
54-
5538
} // namespace jitkernel
5639
} // namespace math
5740
} // namespace operators

paddle/fluid/operators/math/jit_kernel.h

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -87,36 +87,45 @@ class VAddBiasKernel : public Kernel {
8787
};
8888

8989
template <typename T>
90-
class VExpKernel : public Kernel {
90+
class VActKernel : public Kernel {
9191
public:
9292
virtual void Compute(const T *x, T *y) const = 0;
9393
};
9494

9595
template <typename T>
96-
class VSigmoidKernel : public Kernel {
96+
class VReluKernel : public VActKernel<T> {
9797
public:
9898
virtual void Compute(const T *x, T *y) const = 0;
9999
};
100100

101101
template <typename T>
102-
class VTanhKernel : public Kernel {
102+
class VIdentityKernel : public VActKernel<T> {
103103
public:
104104
virtual void Compute(const T *x, T *y) const = 0;
105105
};
106106

107107
template <typename T>
108-
class LSTMKernel : public Kernel {
108+
class VExpKernel : public VActKernel<T> {
109109
public:
110-
explicit LSTMKernel(int d, const std::string &act_gate,
111-
const std::string &act_cand, const std::string &act_cell);
110+
virtual void Compute(const T *x, T *y) const = 0;
111+
};
112112

113-
void (*jit_ker)(T *, const T *, T *, T *);
114-
std::function<void(T *, const T *, T *, T *)> ComputeCtHt, ComputeCtHt_NoC0H0;
113+
template <typename T>
114+
class VSigmoidKernel : public VActKernel<T> {
115+
public:
116+
virtual void Compute(const T *x, T *y) const = 0;
117+
};
115118

116-
private:
117-
int d_, d2_, d3_;
118-
std::function<void(const int, const T *, T *)> act_gate_, act_cell_,
119-
act_cand_;
119+
template <typename T>
120+
class VTanhKernel : public VActKernel<T> {
121+
public:
122+
virtual void Compute(const T *x, T *y) const = 0;
123+
};
124+
125+
template <typename T>
126+
class LSTMKernel : public Kernel {
127+
public:
128+
virtual void ComputeCtHt(T *gates, const T *ct_1, T *ct, T *ht) const = 0;
120129
};
121130

122131
} // namespace jitkernel

paddle/fluid/operators/math/jit_kernel_blas.cc

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,15 +266,124 @@ INTRI16_FLOAT(jit::avx512f);
266266
#endif
267267
// TODO(TJ): eq16 test and complete avx512
268268

269+
#undef INTRI8_FLOAT
270+
#undef INTRI16_FLOAT
271+
272+
/* VRelu JitKernel */
273+
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
274+
class VReluKernelImpl : public VReluKernel<T> {
275+
public:
276+
explicit VReluKernelImpl(int d) : VReluKernel<T>() { this->num_ = d; }
277+
void Compute(const T* x, T* y) const override {
278+
for (int i = 0; i < this->num_; ++i) {
279+
y[i] = x[i] > 0 ? x[i] : 0;
280+
}
281+
}
282+
};
283+
284+
#define INTRI8_FLOAT(isa) \
285+
template <> \
286+
void VReluKernelImpl<float, isa, kEQ8>::Compute(const float* x, float* y) \
287+
const { \
288+
__m256 tmp = _mm256_loadu_ps(x); \
289+
tmp = _mm256_max_ps(tmp, _mm256_setzero_ps()); \
290+
_mm256_storeu_ps(y, tmp); \
291+
}
292+
293+
#define INTRI16_FLOAT(isa) \
294+
template <> \
295+
void VReluKernelImpl<float, isa, kEQ16>::Compute(const float* x, float* y) \
296+
const { \
297+
__m256 zeros = _mm256_setzero_ps(); \
298+
__m256 tmp0 = _mm256_loadu_ps(x); \
299+
__m256 tmp1 = _mm256_loadu_ps(x + 8); \
300+
tmp0 = _mm256_max_ps(tmp0, zeros); \
301+
tmp1 = _mm256_max_ps(tmp1, zeros); \
302+
_mm256_storeu_ps(y, tmp0); \
303+
_mm256_storeu_ps(y + 8, tmp1); \
304+
}
305+
306+
#define INTRI_GT8LT16_FLOAT(isa) \
307+
template <> \
308+
VReluKernelImpl<float, isa, kGT8LT16>::VReluKernelImpl(int d) \
309+
: VReluKernel<float>() { \
310+
this->num_ = d; \
311+
this->end_ = AVX_FLOAT_BLOCK; \
312+
this->rest_ = d - AVX_FLOAT_BLOCK; \
313+
} \
314+
template <> \
315+
void VReluKernelImpl<float, isa, kGT8LT16>::Compute(const float* x, \
316+
float* y) const { \
317+
__m256 zeros = _mm256_setzero_ps(); \
318+
__m256 tmp0 = _mm256_loadu_ps(x); \
319+
__m256 tmp1 = _mm256_loadu_ps(x + this->rest_); \
320+
tmp0 = _mm256_max_ps(tmp0, zeros); \
321+
tmp1 = _mm256_max_ps(tmp1, zeros); \
322+
_mm256_storeu_ps(y, tmp0); \
323+
_mm256_storeu_ps(y + this->rest_, tmp1); \
324+
}
325+
326+
#define INTRI_GT16_FLOAT(isa) \
327+
template <> \
328+
VReluKernelImpl<float, isa, kGT16>::VReluKernelImpl(int d) \
329+
: VReluKernel<float>() { \
330+
this->num_ = d; \
331+
this->end_ = d - d % AVX_FLOAT_BLOCK; \
332+
this->rest_ = d - AVX_FLOAT_BLOCK; \
333+
} \
334+
template <> \
335+
void VReluKernelImpl<float, isa, kGT16>::Compute(const float* x, float* y) \
336+
const { \
337+
__m256 zeros = _mm256_setzero_ps(); \
338+
for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \
339+
__m256 tmp = _mm256_loadu_ps(x + i); \
340+
tmp = _mm256_max_ps(tmp, zeros); \
341+
_mm256_storeu_ps(y + i, tmp); \
342+
} \
343+
__m256 tmp = _mm256_loadu_ps(x + this->rest_); \
344+
tmp = _mm256_max_ps(tmp, zeros); \
345+
_mm256_storeu_ps(y + this->rest_, tmp); \
346+
}
347+
348+
#ifdef __AVX__
349+
INTRI8_FLOAT(jit::avx);
350+
INTRI16_FLOAT(jit::avx);
351+
INTRI_GT8LT16_FLOAT(jit::avx);
352+
INTRI_GT16_FLOAT(jit::avx);
353+
#endif
354+
#ifdef __AVX2__
355+
INTRI8_FLOAT(jit::avx2);
356+
INTRI16_FLOAT(jit::avx2);
357+
INTRI_GT8LT16_FLOAT(jit::avx2);
358+
INTRI_GT16_FLOAT(jit::avx2);
359+
#endif
360+
#ifdef __AVX512F__
361+
// TODO(TJ): refine avx512
362+
INTRI8_FLOAT(jit::avx512f);
363+
INTRI16_FLOAT(jit::avx512f);
364+
INTRI_GT8LT16_FLOAT(jit::avx512f);
365+
INTRI_GT16_FLOAT(jit::avx512f);
366+
#endif
367+
269368
#undef INTRI8_FLOAT
270369
#undef INTRI16_FLOAT
271370
#undef INTRI_GT8LT16_FLOAT
272371
#undef INTRI_GT16_FLOAT
273372

373+
/* An empty JitKernel */
374+
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
375+
class VIdentityKernelImpl : public VIdentityKernel<T> {
376+
public:
377+
explicit VIdentityKernelImpl(int d) : VIdentityKernel<T>() { this->num_ = d; }
378+
void Compute(const T* x, T* y) const override {}
379+
};
380+
274381
REGISTER_JITKERNEL(vmul, VMulKernel);
275382
REGISTER_JITKERNEL(vadd, VAddKernel);
276383
REGISTER_JITKERNEL(vscal, VScalKernel);
277384
REGISTER_JITKERNEL(vaddb, VAddBiasKernel);
385+
REGISTER_JITKERNEL(vrelu, VReluKernel);
386+
REGISTER_JITKERNEL(videntity, VIdentityKernel);
278387

279388
} // namespace jitkernel
280389
} // namespace math

paddle/fluid/operators/math/jit_kernel_exp.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/operators/math/jit_kernel.h"
16+
#include <cmath> // for exp
1617
#include <string>
1718
#include "paddle/fluid/operators/math/jit_kernel_macro.h"
1819
#ifdef PADDLE_WITH_MKLML

paddle/fluid/operators/math/jit_kernel_lstm.cc

Lines changed: 84 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,13 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/operators/math/jit_kernel.h"
16-
#include <functional>
1716
#include <string>
18-
#include "paddle/fluid/operators/math/cpu_vec.h"
17+
#include "paddle/fluid/operators/math/jit_kernel_macro.h"
18+
#include "paddle/fluid/platform/enforce.h"
19+
20+
#ifdef __AVX__
21+
#include <immintrin.h>
22+
#endif
1923

2024
namespace paddle {
2125
namespace operators {
@@ -24,51 +28,85 @@ namespace jitkernel {
2428

2529
namespace jit = platform::jit;
2630

27-
template <>
28-
LSTMKernel<float>::LSTMKernel(int d, const std::string& act_gate_str,
29-
const std::string& act_cand_str,
30-
const std::string& act_cell_str)
31-
: Kernel(), d_(d) {
32-
d2_ = d * 2;
33-
d3_ = d * 3;
34-
if (platform::jit::MayIUse(platform::jit::avx512f)) {
35-
math::VecActivations<float, platform::jit::avx512f> act_functor;
36-
act_gate_ = act_functor(act_gate_str);
37-
act_cell_ = act_functor(act_cell_str);
38-
act_cand_ = act_functor(act_cand_str);
39-
} else if (platform::jit::MayIUse(platform::jit::avx2)) {
40-
math::VecActivations<float, platform::jit::avx2> act_functor;
41-
act_gate_ = act_functor(act_gate_str);
42-
act_cell_ = act_functor(act_cell_str);
43-
act_cand_ = act_functor(act_cand_str);
44-
} else if (platform::jit::MayIUse(platform::jit::avx)) {
45-
math::VecActivations<float, platform::jit::avx> act_functor;
46-
act_gate_ = act_functor(act_gate_str);
47-
act_cell_ = act_functor(act_cell_str);
48-
act_cand_ = act_functor(act_cand_str);
49-
// ComputeCtHt = [&](float*gates,const float*ct_1,float*ct, float*ht) {
50-
// // gates: W_ch, W_ih, W_fh, W_oh
51-
// act_gate(d3_, gates + d_, gates + d_);
52-
53-
// /* C_t = C_t-1 * fgated + cand_gated * igated */
54-
// act_cand(d_, gates, gates);
55-
// blas.VMUL(d_, gates, gates + d_, gates + d_);
56-
// blas.VMUL(d_, ct_1, gates + d2_, gates + d2_);
57-
// blas.VADD(d_, gates + d_, gates + d2_, ct);
58-
59-
// /* H_t = act_cell(C_t) * ogated */
60-
// act_cell(d_, ct, gates + d2_);
61-
// blas.VMUL(d_, gates + d2_, gates + d3_, ht)
62-
// GET_Ct(ct_1, gates, ct);
63-
// GET_Ht(ct, gates, ht);
64-
// };
65-
} else {
66-
math::VecActivations<float, platform::jit::isa_any> act_functor;
67-
act_gate_ = act_functor(act_gate_str);
68-
act_cell_ = act_functor(act_cell_str);
69-
act_cand_ = act_functor(act_cand_str);
31+
/* LSTM JitKernel */
32+
template <typename T, jit::cpu_isa_t isa, jit_block>
33+
class LSTMKernelImpl : public LSTMKernel<T> {
34+
public:
35+
explicit LSTMKernelImpl(int d, const std::string& act_gate,
36+
const std::string& act_cand,
37+
const std::string& act_cell)
38+
: LSTMKernel<T>() {
39+
d_ = d;
40+
d2_ = d * 2;
41+
d3_ = d * 3;
42+
auto GetActKernel = [&](const std::string& type,
43+
int n) -> std::shared_ptr<const VActKernel<T>> {
44+
if (type == "sigmoid") {
45+
return std::dynamic_pointer_cast<const VActKernel<T>>(
46+
KernelPool::Instance().template Get<VSigmoidKernel<T>>(n));
47+
} else if (type == "relu") {
48+
return std::dynamic_pointer_cast<const VActKernel<T>>(
49+
KernelPool::Instance().template Get<VReluKernel<T>>(n));
50+
} else if (type == "tanh") {
51+
return std::dynamic_pointer_cast<const VActKernel<T>>(
52+
KernelPool::Instance().template Get<VTanhKernel<T>>(n));
53+
} else if (type == "identity" || type == "") {
54+
return std::dynamic_pointer_cast<const VActKernel<T>>(
55+
KernelPool::Instance().template Get<VIdentityKernel<T>>(n));
56+
}
57+
PADDLE_THROW("Not support type: %s", type);
58+
};
59+
act_gate_3d_ = GetActKernel(act_gate, d * 3);
60+
act_cand_d_ = GetActKernel(act_cand, d);
61+
act_cell_d_ = GetActKernel(act_cell, d);
62+
vmul_d_ = KernelPool::Instance().template Get<VMulKernel<T>>(d);
63+
vadd_d_ = KernelPool::Instance().template Get<VAddKernel<T>>(d);
64+
}
65+
66+
void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht) const override {
67+
// gates: W_ch, W_ih, W_fh, W_oh
68+
act_gate_3d_->Compute(gates + d_, gates + d_);
69+
70+
/* C_t = C_t-1 * fgated + cand_gated * igated */
71+
act_cand_d_->Compute(gates, gates);
72+
vmul_d_->Compute(gates, gates + d_, gates + d_);
73+
vmul_d_->Compute(ct_1, gates + d2_, gates + d2_);
74+
vadd_d_->Compute(gates + d_, gates + d2_, ct);
75+
76+
/* H_t = act_cell(C_t) * ogated */
77+
act_cell_d_->Compute(ct, gates + d2_);
78+
vmul_d_->Compute(gates + d2_, gates + d3_, ht);
7079
}
71-
}
80+
81+
private:
82+
int d_, d2_, d3_;
83+
std::shared_ptr<const VActKernel<T>> act_gate_3d_, act_cand_d_, act_cell_d_;
84+
std::shared_ptr<const VMulKernel<T>> vmul_d_;
85+
std::shared_ptr<const VAddKernel<T>> vadd_d_;
86+
};
87+
88+
#define JITKERNEL_DECLARE_LSTM(ker_class, ker_dtype) \
89+
template <> \
90+
std::shared_ptr<const ker_class<ker_dtype>> \
91+
KernelPool::Get<ker_class<ker_dtype>, int, const std::string&, \
92+
const std::string&, const std::string&>( \
93+
int d, const std::string& act_gate, const std::string& act_cand, \
94+
const std::string& act_cell)
95+
96+
#define JITKERNEL_KEY_LSTM(ker_key, dtype_key) \
97+
#ker_key #dtype_key + std::to_string(d) + act_gate + act_cand + act_cell
98+
99+
#define JITKERNEL_NEW_LSTM_IMPL(ker, dtype, isa, k) \
100+
p = std::dynamic_pointer_cast<ker<dtype>>( \
101+
std::make_shared<ker##Impl<dtype, isa, k>>(d, act_gate, act_cand, \
102+
act_cell))
103+
104+
REGISTER_JITKERNEL_ARGS(lstm, LSTMKernel, JITKERNEL_DECLARE_LSTM,
105+
JITKERNEL_KEY_LSTM, JITKERNEL_NEW_LSTM_IMPL);
106+
107+
#undef JITKERNEL_DECLARE_LSTM
108+
#undef JITKERNEL_KEY_LSTM
109+
#undef JITKERNEL_NEW_LSTM_IMPL
72110

73111
} // namespace jitkernel
74112
} // namespace math

0 commit comments

Comments (0)