@@ -13,9 +13,76 @@ See the License for the specific language governing permissions and
13
13
limitations under the License. */
14
14
15
15
#include " paddle/fluid/operators/math/cpu_lstm_compute.h"
16
+ #include " paddle/fluid/operators/math/cpu_vec.h"
17
+ #include " paddle/fluid/platform/cpu_info.h"
18
+ #ifdef __AVX__
19
+ #include < immintrin.h>
20
+ #endif
16
21
17
22
namespace paddle {
18
23
namespace operators {
19
- namespace math {} // namespace math
24
+ namespace math {
25
+
26
+ // TODO(TJ): ugly workaround, clean me
27
+ template <typename T>
28
+ void lstm_compute_ctht (T* gates, const T* ct_1, T* ct, T* ht) {
29
+ // gates: W_ch, W_ih, W_fh, W_oh
30
+ vec_sigmoid<T, platform::jit::avx>(24 , gates + 8 , gates + 8 );
31
+ vec_tanh<T, platform::jit::avx>(8 , gates, gates);
32
+ const T *i = gates + 8 , *f = gates + 16 , *o = gates + 24 ;
33
+ const T min = SIGMOID_THRESHOLD_MIN;
34
+ const T max = SIGMOID_THRESHOLD_MAX;
35
+ for (int d = 0 ; d < 8 ; ++d) {
36
+ // C_t = C_t-1 * fgated + cand_gated * igated
37
+ ct[d] = ct_1[d] * f[d] + gates[d] * i[d];
38
+ // H_t = act_cell(C_t) * ogated
39
+ T tmp = ct[d] * 2 ;
40
+ tmp = static_cast <T>(0 ) - ((tmp < min) ? min : ((tmp > max) ? max : tmp));
41
+ vec_exp<T>(1 , &tmp, &tmp);
42
+ tmp = static_cast <T>(2 ) / (static_cast <T>(1 ) + tmp) - static_cast <T>(1 );
43
+ ht[d] = tmp * o[d];
44
+ }
45
+ }
46
+
47
+ #ifdef __AVX__
48
+ namespace detail {
49
+ namespace forward {
50
+ namespace avx {
51
+ __m256 Sigmoid (const __m256 a);
52
+ __m256 Tanh (const __m256 a);
53
+ } // namespace avx
54
+ } // namespace forward
55
+ } // namespace detail
56
+
57
+ template <>
58
+ void lstm_compute_ctht<float >(float * gates, const float * ct_1, float * ct,
59
+ float * ht) {
60
+ namespace act = detail::forward::avx;
61
+ // gates: W_ch, W_ih, W_fh, W_oh
62
+ __m256 c, i, f, o;
63
+ c = _mm256_loadu_ps (gates);
64
+ i = _mm256_loadu_ps (gates + 8 );
65
+ f = _mm256_loadu_ps (gates + 16 );
66
+ o = _mm256_loadu_ps (gates + 24 );
67
+
68
+ /* C_t = C_t-1 * fgated + cand_gated * igated*/
69
+ c = _mm256_mul_ps (act::Tanh (c), act::Sigmoid (i));
70
+ i = _mm256_loadu_ps (ct_1);
71
+ f = _mm256_mul_ps (i, act::Sigmoid (f));
72
+ f = _mm256_add_ps (c, f);
73
+ _mm256_storeu_ps (ct, f);
74
+
75
+ /* H_t = act_cell(C_t) * ogated */
76
+ o = _mm256_mul_ps (act::Tanh (f), act::Sigmoid (o));
77
+ _mm256_storeu_ps (ht, o);
78
+ }
79
+ #endif
80
+
81
+ template void lstm_compute_ctht<float >(float * gates, const float * ct_1,
82
+ float * ct, float * ht);
83
+ template void lstm_compute_ctht<double >(double * gates, const double * ct_1,
84
+ double * ct, double * ht);
85
+
86
+ } // namespace math
20
87
} // namespace operators
21
88
} // namespace paddle
0 commit comments