@@ -13,9 +13,13 @@ See the License for the specific language governing permissions and
13
13
limitations under the License. */
14
14
15
15
#include " paddle/fluid/operators/math/jit_kernel.h"
16
- #include < functional>
17
16
#include < string>
18
- #include " paddle/fluid/operators/math/cpu_vec.h"
17
+ #include " paddle/fluid/operators/math/jit_kernel_macro.h"
18
+ #include " paddle/fluid/platform/enforce.h"
19
+
20
+ #ifdef __AVX__
21
+ #include < immintrin.h>
22
+ #endif
19
23
20
24
namespace paddle {
21
25
namespace operators {
@@ -24,51 +28,85 @@ namespace jitkernel {
24
28
25
29
namespace jit = platform::jit;
26
30
27
- template <>
28
- LSTMKernel<float >::LSTMKernel(int d, const std::string& act_gate_str,
29
- const std::string& act_cand_str,
30
- const std::string& act_cell_str)
31
- : Kernel(), d_(d) {
32
- d2_ = d * 2 ;
33
- d3_ = d * 3 ;
34
- if (platform::jit::MayIUse (platform::jit::avx512f)) {
35
- math::VecActivations<float , platform::jit::avx512f> act_functor;
36
- act_gate_ = act_functor (act_gate_str);
37
- act_cell_ = act_functor (act_cell_str);
38
- act_cand_ = act_functor (act_cand_str);
39
- } else if (platform::jit::MayIUse (platform::jit::avx2)) {
40
- math::VecActivations<float , platform::jit::avx2> act_functor;
41
- act_gate_ = act_functor (act_gate_str);
42
- act_cell_ = act_functor (act_cell_str);
43
- act_cand_ = act_functor (act_cand_str);
44
- } else if (platform::jit::MayIUse (platform::jit::avx)) {
45
- math::VecActivations<float , platform::jit::avx> act_functor;
46
- act_gate_ = act_functor (act_gate_str);
47
- act_cell_ = act_functor (act_cell_str);
48
- act_cand_ = act_functor (act_cand_str);
49
- // ComputeCtHt = [&](float*gates,const float*ct_1,float*ct, float*ht) {
50
- // // gates: W_ch, W_ih, W_fh, W_oh
51
- // act_gate(d3_, gates + d_, gates + d_);
52
-
53
- // /* C_t = C_t-1 * fgated + cand_gated * igated */
54
- // act_cand(d_, gates, gates);
55
- // blas.VMUL(d_, gates, gates + d_, gates + d_);
56
- // blas.VMUL(d_, ct_1, gates + d2_, gates + d2_);
57
- // blas.VADD(d_, gates + d_, gates + d2_, ct);
58
-
59
- // /* H_t = act_cell(C_t) * ogated */
60
- // act_cell(d_, ct, gates + d2_);
61
- // blas.VMUL(d_, gates + d2_, gates + d3_, ht)
62
- // GET_Ct(ct_1, gates, ct);
63
- // GET_Ht(ct, gates, ht);
64
- // };
65
- } else {
66
- math::VecActivations<float , platform::jit::isa_any> act_functor;
67
- act_gate_ = act_functor (act_gate_str);
68
- act_cell_ = act_functor (act_cell_str);
69
- act_cand_ = act_functor (act_cand_str);
31
+ /* LSTM JitKernel */
32
+ template <typename T, jit::cpu_isa_t isa, jit_block>
33
+ class LSTMKernelImpl : public LSTMKernel <T> {
34
+ public:
35
+ explicit LSTMKernelImpl (int d, const std::string& act_gate,
36
+ const std::string& act_cand,
37
+ const std::string& act_cell)
38
+ : LSTMKernel<T>() {
39
+ d_ = d;
40
+ d2_ = d * 2 ;
41
+ d3_ = d * 3 ;
42
+ auto GetActKernel = [&](const std::string& type,
43
+ int n) -> std::shared_ptr<const VActKernel<T>> {
44
+ if (type == " sigmoid" ) {
45
+ return std::dynamic_pointer_cast<const VActKernel<T>>(
46
+ KernelPool::Instance ().template Get <VSigmoidKernel<T>>(n));
47
+ } else if (type == " relu" ) {
48
+ return std::dynamic_pointer_cast<const VActKernel<T>>(
49
+ KernelPool::Instance ().template Get <VReluKernel<T>>(n));
50
+ } else if (type == " tanh" ) {
51
+ return std::dynamic_pointer_cast<const VActKernel<T>>(
52
+ KernelPool::Instance ().template Get <VTanhKernel<T>>(n));
53
+ } else if (type == " identity" || type == " " ) {
54
+ return std::dynamic_pointer_cast<const VActKernel<T>>(
55
+ KernelPool::Instance ().template Get <VIdentityKernel<T>>(n));
56
+ }
57
+ PADDLE_THROW (" Not support type: %s" , type);
58
+ };
59
+ act_gate_3d_ = GetActKernel (act_gate, d * 3 );
60
+ act_cand_d_ = GetActKernel (act_cand, d);
61
+ act_cell_d_ = GetActKernel (act_cell, d);
62
+ vmul_d_ = KernelPool::Instance ().template Get <VMulKernel<T>>(d);
63
+ vadd_d_ = KernelPool::Instance ().template Get <VAddKernel<T>>(d);
64
+ }
65
+
66
+ void ComputeCtHt (T* gates, const T* ct_1, T* ct, T* ht) const override {
67
+ // gates: W_ch, W_ih, W_fh, W_oh
68
+ act_gate_3d_->Compute (gates + d_, gates + d_);
69
+
70
+ /* C_t = C_t-1 * fgated + cand_gated * igated */
71
+ act_cand_d_->Compute (gates, gates);
72
+ vmul_d_->Compute (gates, gates + d_, gates + d_);
73
+ vmul_d_->Compute (ct_1, gates + d2_, gates + d2_);
74
+ vadd_d_->Compute (gates + d_, gates + d2_, ct);
75
+
76
+ /* H_t = act_cell(C_t) * ogated */
77
+ act_cell_d_->Compute (ct, gates + d2_);
78
+ vmul_d_->Compute (gates + d2_, gates + d3_, ht);
70
79
}
71
- }
80
+
81
+ private:
82
+ int d_, d2_, d3_;
83
+ std::shared_ptr<const VActKernel<T>> act_gate_3d_, act_cand_d_, act_cell_d_;
84
+ std::shared_ptr<const VMulKernel<T>> vmul_d_;
85
+ std::shared_ptr<const VAddKernel<T>> vadd_d_;
86
+ };
87
+
88
// Expands to the declarator of the explicit KernelPool::Get specialization
// for an LSTM kernel class: (int d, act_gate, act_cand, act_cell) returning a
// shared_ptr to a const kernel. The function body is supplied by the macro
// that consumes this one.
// Note: there must be no whitespace between the macro name and '(' -- a space
// there would turn this into an object-like macro.
#define JITKERNEL_DECLARE_LSTM(ker_class, ker_dtype)                   \
  template <>                                                          \
  std::shared_ptr<const ker_class<ker_dtype>>                          \
  KernelPool::Get<ker_class<ker_dtype>, int, const std::string&,       \
                  const std::string&, const std::string&>(             \
      int d, const std::string& act_gate, const std::string& act_cand, \
      const std::string& act_cell)
95
+
96
// Builds the string key identifying one LSTM kernel configuration from the
// stringized kernel/dtype tags, the size `d` and the three activation names
// (`d`, `act_gate`, `act_cand`, `act_cell` must be in scope at the expansion
// site). The activation names are joined with '_' because the empty string is
// an accepted name: without a separator ("relu", "", "tanh") and
// ("", "relu", "tanh") would yield the same key.
#define JITKERNEL_KEY_LSTM(ker_key, dtype_key)                        \
  (#ker_key #dtype_key + std::to_string(d) + "_" + act_gate + "_" +   \
   act_cand + "_" + act_cell)
98
+
99
// Creates the concrete `<ker>Impl<dtype, isa, k>` kernel from the arguments of
// the enclosing Get specialization and stores it, converted to a pointer to
// `ker<dtype>`, in `p`. `p` is not declared here; it is presumably provided by
// the registering macro -- confirm in jit_kernel_macro.h.
// Note: there must be no whitespace between the macro name and '('.
#define JITKERNEL_NEW_LSTM_IMPL(ker, dtype, isa, k)                     \
  p = std::dynamic_pointer_cast<ker<dtype>>(                            \
      std::make_shared<ker##Impl<dtype, isa, k>>(d, act_gate, act_cand, \
                                                 act_cell))
103
+
104
+ REGISTER_JITKERNEL_ARGS (lstm, LSTMKernel, JITKERNEL_DECLARE_LSTM,
105
+ JITKERNEL_KEY_LSTM, JITKERNEL_NEW_LSTM_IMPL);
106
+
107
+ #undef JITKERNEL_DECLARE_LSTM
108
+ #undef JITKERNEL_KEY_LSTM
109
+ #undef JITKERNEL_NEW_LSTM_IMPL
72
110
73
111
} // namespace jitkernel
74
112
} // namespace math
0 commit comments