Skip to content

Commit b37fe30

Browse files
authored
Merge pull request #13690 from wangguibao/fix_cpu_lstm_compute_cc
Avoid multiple definitions of lstm_compute_ctht when linking libpaddle_fluid.so
2 parents 26771f4 + 1940bc2 commit b37fe30

File tree

2 files changed

+28
-20
lines changed

2 files changed

+28
-20
lines changed

paddle/fluid/operators/math/cpu_lstm_compute.cc

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,31 @@ limitations under the License. */
1313

1414
namespace paddle {
1515
namespace operators {
16-
namespace math {} // namespace math
16+
namespace math {
17+
#ifdef __AVX__
18+
template <>
19+
void lstm_compute_ctht<float>(float* gates, const float* ct_1, float* ct,
20+
float* ht) {
21+
namespace act = detail::forward::avx;
22+
// gates: W_ch, W_ih, W_fh, W_oh
23+
__m256 c, i, f, o;
24+
c = _mm256_loadu_ps(gates);
25+
i = _mm256_loadu_ps(gates + 8);
26+
f = _mm256_loadu_ps(gates + 16);
27+
o = _mm256_loadu_ps(gates + 24);
28+
29+
/* C_t = C_t-1 * fgated + cand_gated * igated*/
30+
c = _mm256_mul_ps(act::Tanh(c), act::Sigmoid(i));
31+
i = _mm256_loadu_ps(ct_1);
32+
f = _mm256_mul_ps(i, act::Sigmoid(f));
33+
f = _mm256_add_ps(c, f);
34+
_mm256_storeu_ps(ct, f);
35+
36+
/* H_t = act_cell(C_t) * ogated */
37+
o = _mm256_mul_ps(act::Tanh(f), act::Sigmoid(o));
38+
_mm256_storeu_ps(ht, o);
39+
}
40+
#endif
41+
} // namespace math
1742
} // namespace operators
1843
} // namespace paddle

paddle/fluid/operators/math/cpu_lstm_compute.h

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -48,32 +48,15 @@ namespace forward {
4848
namespace avx {
4949
__m256 Sigmoid(const __m256 a);
5050
__m256 Tanh(const __m256 a);
51+
5152
} // namespace avx
5253
} // namespace forward
5354
} // namespace detail
5455

5556
template <>
5657
void lstm_compute_ctht<float>(float* gates, const float* ct_1, float* ct,
57-
float* ht) {
58-
namespace act = detail::forward::avx;
59-
// gates: W_ch, W_ih, W_fh, W_oh
60-
__m256 c, i, f, o;
61-
c = _mm256_loadu_ps(gates);
62-
i = _mm256_loadu_ps(gates + 8);
63-
f = _mm256_loadu_ps(gates + 16);
64-
o = _mm256_loadu_ps(gates + 24);
58+
float* ht);
6559

66-
/* C_t = C_t-1 * fgated + cand_gated * igated*/
67-
c = _mm256_mul_ps(act::Tanh(c), act::Sigmoid(i));
68-
i = _mm256_loadu_ps(ct_1);
69-
f = _mm256_mul_ps(i, act::Sigmoid(f));
70-
f = _mm256_add_ps(c, f);
71-
_mm256_storeu_ps(ct, f);
72-
73-
/* H_t = act_cell(C_t) * ogated */
74-
o = _mm256_mul_ps(act::Tanh(f), act::Sigmoid(o));
75-
_mm256_storeu_ps(ht, o);
76-
}
7760
#endif
7861

7962
} // namespace math

0 commit comments

Comments
 (0)