Skip to content

Commit ff858d3

Browse files
committed
fix bug and enable on batch mode as well
1 parent 8dea07f commit ff858d3

File tree

3 files changed

+52
-44
lines changed

3 files changed

+52
-44
lines changed

paddle/fluid/operators/fusion_lstm_op.cc

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -543,12 +543,23 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
543543
MOVE_ONE_STEP;
544544
}
545545
} else {
546+
// TODO(TJ): ugly workaround, clean me
547+
std::function<void(T*, const T*, T*, T*)> compute_ctht;
548+
if (platform::jit::MayIUse(platform::jit::avx) &&
549+
act_gate_str == "sigmoid" && act_cand_str == "tanh" &&
550+
act_cell_str == "tanh" && D == 8) {
551+
compute_ctht = math::lstm_compute_ctht<T>;
552+
} else {
553+
compute_ctht = [&](T* gates, const T* ct_1, T* ct, T* ht) {
554+
COMPUTE_CtHt(gates, ct_1, ct, ht);
555+
};
556+
}
546557
for (int step = tstart; step < max_seq_len; ++step) {
547558
const int cur_bs = batch_starts[step + 1] - batch_starts[step];
548559
GEMM_WH_ADDON(cur_bs, prev_h_data, batched_input_data);
549560
DEFINE_CUR;
550561
for (int i = 0; i < cur_bs; ++i) {
551-
COMPUTE_CtHt(cur_in_data, cur_prev_c_data, cur_c_out_data,
562+
compute_ctht(cur_in_data, cur_prev_c_data, cur_c_out_data,
552563
cur_h_out_data);
553564
MOVE_ONE_BATCH;
554565
}

paddle/fluid/operators/math/cpu_lstm_compute.cc

Lines changed: 2 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -13,49 +13,9 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/operators/math/cpu_lstm_compute.h"
16-
#ifdef __AVX__
17-
#include <immintrin.h>
18-
#endif
16+
1917
namespace paddle {
2018
namespace operators {
21-
namespace math {
22-
23-
#ifdef __AVX__
24-
// TODO(TJ): ugly workaround, clean me
25-
26-
namespace detail {
27-
namespace forward {
28-
namespace avx {
29-
__m256 Sigmoid(const __m256 a);
30-
__m256 Tanh(const __m256 a);
31-
} // namespace avx
32-
} // namespace forward
33-
} // namespace detail
34-
35-
template <>
36-
void lstm_compute_ctht<float>(float* gates, const float* ct_1, float* ct,
37-
float* ht) {
38-
namespace act = detail::forward::avx;
39-
// gates: W_ch, W_ih, W_fh, W_oh
40-
__m256 c, i, f, o;
41-
c = _mm256_loadu_ps(gates);
42-
i = _mm256_loadu_ps(gates + 8);
43-
f = _mm256_loadu_ps(gates + 16);
44-
o = _mm256_loadu_ps(gates + 24);
45-
46-
/* C_t = C_t-1 * fgated + cand_gated * igated*/
47-
c = _mm256_mul_ps(act::Tanh(c), act::Sigmoid(i));
48-
i = _mm256_loadu_ps(ct_1);
49-
f = _mm256_mul_ps(i, act::Sigmoid(f));
50-
f = _mm256_add_ps(c, f);
51-
_mm256_storeu_ps(ct, f);
52-
53-
/* H_t = act_cell(C_t) * ogated */
54-
o = _mm256_mul_ps(act::Tanh(f), act::Sigmoid(o));
55-
_mm256_storeu_ps(ht, o);
56-
}
57-
#endif
58-
59-
} // namespace math
19+
namespace math {} // namespace math
6020
} // namespace operators
6121
} // namespace paddle

paddle/fluid/operators/math/cpu_lstm_compute.h

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ limitations under the License. */
1616
#include <string>
1717
#include "paddle/fluid/operators/math/cpu_vec.h"
1818
#include "paddle/fluid/platform/cpu_info.h"
19+
#ifdef __AVX__
20+
#include <immintrin.h>
21+
#endif
1922

2023
namespace paddle {
2124
namespace operators {
@@ -35,13 +38,47 @@ void lstm_compute_ctht(T* gates, const T* ct_1, T* ct, T* ht) {
3538
ct[d] = ct_1[d] * f[d] + gates[d] * i[d];
3639
// H_t = act_cell(C_t) * ogated
3740
T tmp = ct[d] * 2;
38-
tmp = static_cast<T>(0) - (tmp < min) ? min : ((tmp > max) ? max : tmp);
41+
tmp = static_cast<T>(0) - ((tmp < min) ? min : ((tmp > max) ? max : tmp));
3942
vec_exp<T>(1, &tmp, &tmp);
4043
tmp = static_cast<T>(2) / (static_cast<T>(1) + tmp) - static_cast<T>(1);
4144
ht[d] = tmp * o[d];
4245
}
4346
}
4447

48+
#ifdef __AVX__
49+
namespace detail {
50+
namespace forward {
51+
namespace avx {
52+
__m256 Sigmoid(const __m256 a);
53+
__m256 Tanh(const __m256 a);
54+
} // namespace avx
55+
} // namespace forward
56+
} // namespace detail
57+
58+
template <>
59+
void lstm_compute_ctht<float>(float* gates, const float* ct_1, float* ct,
60+
float* ht) {
61+
namespace act = detail::forward::avx;
62+
// gates: W_ch, W_ih, W_fh, W_oh
63+
__m256 c, i, f, o;
64+
c = _mm256_loadu_ps(gates);
65+
i = _mm256_loadu_ps(gates + 8);
66+
f = _mm256_loadu_ps(gates + 16);
67+
o = _mm256_loadu_ps(gates + 24);
68+
69+
/* C_t = C_{t-1} * fgated + cand_gated * igated */
70+
c = _mm256_mul_ps(act::Tanh(c), act::Sigmoid(i));
71+
i = _mm256_loadu_ps(ct_1);
72+
f = _mm256_mul_ps(i, act::Sigmoid(f));
73+
f = _mm256_add_ps(c, f);
74+
_mm256_storeu_ps(ct, f);
75+
76+
/* H_t = act_cell(C_t) * ogated */
77+
o = _mm256_mul_ps(act::Tanh(f), act::Sigmoid(o));
78+
_mm256_storeu_ps(ht, o);
79+
}
80+
#endif
81+
4582
} // namespace math
4683
} // namespace operators
4784
} // namespace paddle

0 commit comments

Comments
 (0)