Skip to content

Commit e09cf03

Browse files
committed
refine src and header
1 parent 3db1e41 commit e09cf03

File tree

2 files changed

+69
-58
lines changed

2 files changed

+69
-58
lines changed

paddle/fluid/operators/math/cpu_lstm_compute.cc

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,76 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/operators/math/cpu_lstm_compute.h"
16+
#include "paddle/fluid/operators/math/cpu_vec.h"
17+
#include "paddle/fluid/platform/cpu_info.h"
18+
#ifdef __AVX__
19+
#include <immintrin.h>
20+
#endif
1621

1722
namespace paddle {
1823
namespace operators {
19-
namespace math {} // namespace math
24+
namespace math {
25+
26+
// TODO(TJ): ugly workaround, clean me
27+
// Generic (non-intrinsic) LSTM step: from the four gate pre-activations and
// the previous cell state, computes the new cell state and the hidden output.
//
// The width is hard-coded: exactly 8 elements per gate are processed, so
// `gates` is indexed over [0, 32) and `ct_1`, `ct`, `ht` over [0, 8).
//
// gates: four consecutive groups of 8 values -- candidate at +0, input gate
//        at +8, forget gate at +16, output gate at +24. The same pointer is
//        passed as both arguments of vec_sigmoid / vec_tanh, so the buffer is
//        presumably overwritten with the activated values
//        (NOTE(review): confirm the argument order in cpu_vec.h).
// ct_1:  previous cell state (read only).
// ct:    receives the new cell state.
// ht:    receives the new hidden output.
template <typename T>
28+
void lstm_compute_ctht(T* gates, const T* ct_1, T* ct, T* ht) {
29+
// gates: W_ch, W_ih, W_fh, W_oh
30+
// One call covers the 24 contiguous input/forget/output gate values.
vec_sigmoid<T, platform::jit::avx>(24, gates + 8, gates + 8);
31+
// The first 8 values are the candidate and get tanh instead.
vec_tanh<T, platform::jit::avx>(8, gates, gates);
32+
const T *i = gates + 8, *f = gates + 16, *o = gates + 24;
33+
// Clamp bounds; SIGMOID_THRESHOLD_MIN/MAX are defined outside this file.
const T min = SIGMOID_THRESHOLD_MIN;
34+
const T max = SIGMOID_THRESHOLD_MAX;
35+
for (int d = 0; d < 8; ++d) {
36+
// C_t = C_t-1 * fgated + cand_gated * igated
37+
ct[d] = ct_1[d] * f[d] + gates[d] * i[d];
38+
// H_t = act_cell(C_t) * ogated
39+
// act_cell is evaluated by hand as 2 / (1 + e^(-2x)) - 1, which is the
// tanh identity, with 2x clamped to [min, max] before negation.
T tmp = ct[d] * 2;
40+
tmp = static_cast<T>(0) - ((tmp < min) ? min : ((tmp > max) ? max : tmp));
41+
// Single-element call, in place; presumably tmp = exp(tmp)
// (NOTE(review): vec_exp is declared elsewhere -- confirm).
vec_exp<T>(1, &tmp, &tmp);
42+
tmp = static_cast<T>(2) / (static_cast<T>(1) + tmp) - static_cast<T>(1);
43+
ht[d] = tmp * o[d];
44+
}
45+
}
46+
47+
#ifdef __AVX__
48+
// Forward declarations only: the 8-lane AVX activation helpers used by the
// float specialization below. Their definitions are not in this file, so the
// linker must find them in another translation unit.
namespace detail {
49+
namespace forward {
50+
namespace avx {
51+
// Takes and returns one __m256 (8 packed floats).
__m256 Sigmoid(const __m256 a);
52+
// Takes and returns one __m256 (8 packed floats).
__m256 Tanh(const __m256 a);
53+
} // namespace avx
54+
} // namespace forward
55+
} // namespace detail
56+
57+
// AVX specialization for float: the same LSTM step as the generic template,
// done on four 256-bit registers (8 floats each) instead of a scalar loop.
//
// gates: 32 floats -- candidate at +0, input gate at +8, forget gate at +16,
//        output gate at +24. Only read here; nothing is stored back to it.
// ct_1:  8 floats, previous cell state (read only).
// ct:    8 floats, receives the new cell state.
// ht:    8 floats, receives the new hidden output.
// All loads and stores use the unaligned intrinsics, so no pointer needs
// 32-byte alignment.
template <>
58+
void lstm_compute_ctht<float>(float* gates, const float* ct_1, float* ct,
59+
float* ht) {
60+
namespace act = detail::forward::avx;
61+
// gates: W_ch, W_ih, W_fh, W_oh
62+
__m256 c, i, f, o;
63+
c = _mm256_loadu_ps(gates);
64+
i = _mm256_loadu_ps(gates + 8);
65+
f = _mm256_loadu_ps(gates + 16);
66+
o = _mm256_loadu_ps(gates + 24);
67+
68+
/* C_t = C_t-1 * fgated + cand_gated * igated */
69+
// c now holds cand_gated * igated.
c = _mm256_mul_ps(act::Tanh(c), act::Sigmoid(i));
70+
// Register i is reused: from here on it holds C_t-1, not the input gate.
i = _mm256_loadu_ps(ct_1);
71+
f = _mm256_mul_ps(i, act::Sigmoid(f));
72+
// Register f is reused: after this add it holds the new cell state C_t.
f = _mm256_add_ps(c, f);
73+
_mm256_storeu_ps(ct, f);
74+
75+
/* H_t = act_cell(C_t) * ogated */
76+
o = _mm256_mul_ps(act::Tanh(f), act::Sigmoid(o));
77+
_mm256_storeu_ps(ht, o);
78+
}
79+
#endif
80+
81+
// Explicit instantiations: the template body lives in this .cc file, so the
// float and double versions must be emitted here for other translation units
// to link against. Any other T will fail at link time.
// When __AVX__ is defined, the float line follows the explicit specialization
// earlier in this file; under the C++ rules for explicit instantiation it
// then has no effect, and the specialization is what gets used.
template void lstm_compute_ctht<float>(float* gates, const float* ct_1,
82+
float* ct, float* ht);
83+
template void lstm_compute_ctht<double>(double* gates, const double* ct_1,
84+
double* ct, double* ht);
85+
86+
} // namespace math
2087
} // namespace operators
2188
} // namespace paddle

paddle/fluid/operators/math/cpu_lstm_compute.h

Lines changed: 1 addition & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -14,70 +14,14 @@ limitations under the License. */
1414

1515
#pragma once
1616
#include <string>
17-
#include "paddle/fluid/operators/math/cpu_vec.h"
18-
#include "paddle/fluid/platform/cpu_info.h"
19-
#ifdef __AVX__
20-
#include <immintrin.h>
21-
#endif
2217

2318
namespace paddle {
2419
namespace operators {
2520
namespace math {
2621

2722
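// TODO(TJ): ugly workaround, clean me
// Declaration only: the definition and the explicit instantiations (float and
// double) are in cpu_lstm_compute.cc, so no other element type will link.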
2823
template <typename T>
29-
void lstm_compute_ctht(T* gates, const T* ct_1, T* ct, T* ht) {
30-
// gates: W_ch, W_ih, W_fh, W_oh
31-
vec_sigmoid<T, platform::jit::avx>(24, gates + 8, gates + 8);
32-
vec_tanh<T, platform::jit::avx>(8, gates, gates);
33-
const T *i = gates + 8, *f = gates + 16, *o = gates + 24;
34-
const T min = SIGMOID_THRESHOLD_MIN;
35-
const T max = SIGMOID_THRESHOLD_MAX;
36-
for (int d = 0; d < 8; ++d) {
37-
// C_t = C_t-1 * fgated + cand_gated * igated
38-
ct[d] = ct_1[d] * f[d] + gates[d] * i[d];
39-
// H_t = act_cell(C_t) * ogated
40-
T tmp = ct[d] * 2;
41-
tmp = static_cast<T>(0) - ((tmp < min) ? min : ((tmp > max) ? max : tmp));
42-
vec_exp<T>(1, &tmp, &tmp);
43-
tmp = static_cast<T>(2) / (static_cast<T>(1) + tmp) - static_cast<T>(1);
44-
ht[d] = tmp * o[d];
45-
}
46-
}
47-
48-
#ifdef __AVX__
49-
namespace detail {
50-
namespace forward {
51-
namespace avx {
52-
__m256 Sigmoid(const __m256 a);
53-
__m256 Tanh(const __m256 a);
54-
} // namespace avx
55-
} // namespace forward
56-
} // namespace detail
57-
58-
template <>
59-
void lstm_compute_ctht<float>(float* gates, const float* ct_1, float* ct,
60-
float* ht) {
61-
namespace act = detail::forward::avx;
62-
// gates: W_ch, W_ih, W_fh, W_oh
63-
__m256 c, i, f, o;
64-
c = _mm256_loadu_ps(gates);
65-
i = _mm256_loadu_ps(gates + 8);
66-
f = _mm256_loadu_ps(gates + 16);
67-
o = _mm256_loadu_ps(gates + 24);
68-
69-
/* C_t = C_t-1 * fgated + cand_gated * igated*/
70-
c = _mm256_mul_ps(act::Tanh(c), act::Sigmoid(i));
71-
i = _mm256_loadu_ps(ct_1);
72-
f = _mm256_mul_ps(i, act::Sigmoid(f));
73-
f = _mm256_add_ps(c, f);
74-
_mm256_storeu_ps(ct, f);
75-
76-
/* H_t = act_cell(C_t) * ogated */
77-
o = _mm256_mul_ps(act::Tanh(f), act::Sigmoid(o));
78-
_mm256_storeu_ps(ht, o);
79-
}
80-
#endif
24+
void lstm_compute_ctht(T* gates, const T* ct_1, T* ct, T* ht);
8125

8226
} // namespace math
8327
} // namespace operators

0 commit comments

Comments
 (0)