Skip to content

Commit 8dea07f

Browse files
committed
fix comopile
1 parent 612ba41 commit 8dea07f

File tree

3 files changed

+13
-12
lines changed

3 files changed

+13
-12
lines changed

paddle/fluid/operators/fusion_lstm_op.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -396,15 +396,15 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
396396
}
397397
} else {
398398
// TODO(TJ): unly workaround, clean me
399-
std::function<void(const T*, const T*, T*, T*)> compute_ctht;
399+
std::function<void(T*, const T*, T*, T*)> compute_ctht;
400400
if (platform::jit::MayIUse(platform::jit::avx) &&
401401
act_gate_str == "sigmoid" && act_cand_str == "tanh" &&
402402
act_cell_str == "tanh" && D == 8) {
403403
compute_ctht = math::lstm_compute_ctht<T>;
404404
} else {
405-
compute_ctht = [&](const T* gates, const T* ct_1, T* ct, T* ht) {
405+
compute_ctht = [&](T* gates, const T* ct_1, T* ct, T* ht) {
406406
COMPUTE_CtHt(gates, ct_1, ct, ht);
407-
}
407+
};
408408
}
409409
for (int i = 0; i < N; ++i) {
410410
PROCESS_H0C0

paddle/fluid/operators/math/cpu_lstm_compute.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,15 @@ namespace math {
2525

2626
namespace detail {
2727
namespace forward {
28-
namespace avx {} // namespace avx
28+
namespace avx {
29+
__m256 Sigmoid(const __m256 a);
30+
__m256 Tanh(const __m256 a);
31+
} // namespace avx
2932
} // namespace forward
3033
} // namespace detail
3134

3235
template <>
33-
void lstm_compute_ctht<float>(const float* gates, const float* ct_1, float* ct,
36+
void lstm_compute_ctht<float>(float* gates, const float* ct_1, float* ct,
3437
float* ht) {
3538
namespace act = detail::forward::avx;
3639
// gates: W_ch, W_ih, W_fh, W_oh
@@ -52,6 +55,7 @@ void lstm_compute_ctht<float>(const float* gates, const float* ct_1, float* ct,
5255
_mm256_storeu_ps(ht, o);
5356
}
5457
#endif
58+
5559
} // namespace math
5660
} // namespace operators
5761
} // namespace paddle

paddle/fluid/operators/math/cpu_lstm_compute.h

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,22 +23,19 @@ namespace math {
2323

2424
// TODO(TJ): ugly workaround, clean me
2525
template <typename T>
26-
void lstm_compute_ctht(const T* gates, const T* ct_1, T* ct, T* ht) {
26+
void lstm_compute_ctht(T* gates, const T* ct_1, T* ct, T* ht) {
2727
// gates: W_ch, W_ih, W_fh, W_oh
2828
vec_sigmoid<T, platform::jit::avx>(24, gates + 8, gates + 8);
2929
vec_tanh<T, platform::jit::avx>(8, gates, gates);
3030
const T *i = gates + 8, *f = gates + 16, *o = gates + 24;
31+
const T min = SIGMOID_THRESHOLD_MIN;
32+
const T max = SIGMOID_THRESHOLD_MAX;
3133
for (int d = 0; d < 8; ++d) {
3234
// C_t = C_t-1 * fgated + cand_gated * igated
3335
ct[d] = ct_1[d] * f[d] + gates[d] * i[d];
34-
3536
// H_t = act_cell(C_t) * ogated
3637
T tmp = ct[d] * 2;
37-
tmp = static_cast<T>(0) - (tmp < static_cast<T>(SIGMOID_THRESHOLD_MIN))
38-
? min
39-
: ((tmp > static_cast<T>(SIGMOID_THRESHOLD_MAX))
40-
? static_cast<T>(SIGMOID_THRESHOLD_MAX)
41-
: tmp);
38+
tmp = static_cast<T>(0) - (tmp < min) ? min : ((tmp > max) ? max : tmp);
4239
vec_exp<T>(1, &tmp, &tmp);
4340
tmp = static_cast<T>(2) / (static_cast<T>(1) + tmp) - static_cast<T>(1);
4441
ht[d] = tmp * o[d];

0 commit comments

Comments
 (0)