
Commit b4751a3

fix illegal instruction of rnn2
1 parent 36588b3 commit b4751a3

2 files changed: +125 additions, -79 deletions

paddle/fluid/operators/math/jit_kernel_exp.cc

Lines changed: 3 additions & 9 deletions
@@ -27,13 +27,6 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 namespace math {
-
-#ifdef __AVX__
-namespace detail {
-__m256 Exp(__m256 a);
-}  // namespace detail
-#endif
-
 namespace jitkernel {
 namespace jit = platform::jit;

@@ -205,7 +198,7 @@ __m256 ExpAVX(__m256 x) {
 #ifdef __AVX2__
 __m256 ExpAVX2(__m256 x) {
   __m256 tmp = _mm256_setzero_ps(), fx;
-  __m256 one = *reinterpret_cast<const __m256*> _ps256_one;
+  __m256 one = *reinterpret_cast<const __m256*>(_ps256_one);
   __m256i imm0;

   x = _mm256_min_ps(x, *reinterpret_cast<const __m256*>(_ps256_exp_hi));

@@ -335,7 +328,8 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
   template <> \
   void VSigmoidKernelImpl<float, isa, kEQ8>::Compute(const float* x, float* y) \
       const { \
-    /*use static const??*/ __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \
+    /* TODO(TJ): try to use static const*/ \
+    __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \
     __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
     __m256 tmp = _mm256_loadu_ps(x); \
     INTRI_SIGMOID(tmp, min, max, expisa); \
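
For context on the one-line fix in ExpAVX2 above: reinterpret_cast requires a parenthesized operand, so the packed constant has to be loaded as in the corrected line. Below is a minimal standalone sketch of that load pattern; the _ps256_one definition here is an assumption for illustration only (the real constant is defined elsewhere in jit_kernel_exp.cc), and only the cast expression mirrors the diff.

#include <immintrin.h>

// Assumed stand-in for the real _ps256_one: eight packed 1.0f values,
// 32-byte aligned so they can be read directly as one __m256.
alignas(32) static const float _ps256_one[8] = {1.f, 1.f, 1.f, 1.f,
                                                1.f, 1.f, 1.f, 1.f};

static inline __m256 LoadOne() {
  // The cast operand must be wrapped in parentheses, as in the fixed line.
  return *reinterpret_cast<const __m256*>(_ps256_one);
}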

paddle/fluid/operators/math/jit_kernel_lstm.cc

Lines changed: 122 additions & 70 deletions
@@ -25,13 +25,18 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 namespace math {
-#ifdef __AVX__
+namespace jitkernel {
 namespace detail {
-__m256 Exp(__m256 a);
-}  // namespace detail
+#ifdef __AVX__
+__m256 ExpAVX(__m256 x);
 #endif

-namespace jitkernel {
+#ifdef __AVX2__
+__m256 ExpAVX2(__m256 x);
+#endif
+
+}  // namespace detail
+
 namespace jit = platform::jit;

 #ifdef __AVX__

@@ -43,43 +48,72 @@ class AVXAct {
   virtual __m256 Compute(__m256 x) const = 0;
 };

-template <act_type type>
+template <act_type type, jit::cpu_isa_t isa>
 class AVXActImpl : public AVXAct {
  public:
   __m256 Compute(__m256 x) const override { PADDLE_THROW("Unkown type!"); }
 };

-template <>
-__m256 AVXActImpl<kSigmoid>::Compute(__m256 x) const {
-  __m256 ones = _mm256_set1_ps(1.0f);
-  x = _mm256_max_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MIN));
-  x = _mm256_min_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MAX));
-  x = _mm256_sub_ps(_mm256_set1_ps(0.0f), x);
-  x = detail::Exp(x);
-  x = _mm256_add_ps(ones, x);
-  return _mm256_div_ps(ones, x);
-}
+#define AVX_SIGMOID(isa, expisa) \
+  template <> \
+  __m256 AVXActImpl<kSigmoid, isa>::Compute(__m256 x) const { \
+    __m256 ones = _mm256_set1_ps(1.0f); \
+    x = _mm256_max_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MIN)); \
+    x = _mm256_min_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MAX)); \
+    x = _mm256_sub_ps(_mm256_set1_ps(0.0f), x); \
+    x = expisa(x); \
+    x = _mm256_add_ps(ones, x); \
+    return _mm256_div_ps(ones, x); \
+  }

-template <>
-__m256 AVXActImpl<kTanh>::Compute(__m256 x) const {
-  __m256 ones = _mm256_set1_ps(1.0f);
-  x = _mm256_mul_ps(_mm256_set1_ps(-2.0f), x);
-  x = _mm256_min_ps(x, _mm256_set1_ps(EXP_MAX_INPUT));
-  x = detail::Exp(x);
-  x = _mm256_add_ps(ones, x);
-  x = _mm256_div_ps(_mm256_set1_ps(2.0f), x);
-  return _mm256_sub_ps(x, ones);
-}
+#define AVX_TANH(isa, expisa) \
+  template <> \
+  __m256 AVXActImpl<kTanh, isa>::Compute(__m256 x) const { \
+    __m256 ones = _mm256_set1_ps(1.0f); \
+    x = _mm256_mul_ps(_mm256_set1_ps(-2.0f), x); \
+    x = _mm256_min_ps(x, _mm256_set1_ps(EXP_MAX_INPUT)); \
+    x = expisa(x); \
+    x = _mm256_add_ps(ones, x); \
+    x = _mm256_div_ps(_mm256_set1_ps(2.0f), x); \
+    return _mm256_sub_ps(x, ones); \
+  }

-template <>
-__m256 AVXActImpl<kRelu>::Compute(__m256 x) const {
-  return _mm256_max_ps(x, _mm256_setzero_ps());
-}
+#define AVX_RELU(isa) \
+  template <> \
+  __m256 AVXActImpl<kRelu, isa>::Compute(__m256 x) const { \
+    return _mm256_max_ps(x, _mm256_setzero_ps()); \
+  }
+
+#define AVX_IDENTITY(isa) \
+  template <> \
+  __m256 AVXActImpl<kIdentity, isa>::Compute(__m256 x) const { \
+    return x; \
+  }
+
+#define FOR_EACH_AVX_ISA(macro_) \
+  macro_(jit::avx); \
+  macro_(jit::avx2); \
+  macro_(jit::avx512f)
+
+FOR_EACH_AVX_ISA(AVX_RELU);
+FOR_EACH_AVX_ISA(AVX_IDENTITY);
+
+AVX_SIGMOID(jit::avx, detail::ExpAVX);
+AVX_TANH(jit::avx, detail::ExpAVX);
+
+#ifdef __AVX2__
+AVX_SIGMOID(jit::avx2, detail::ExpAVX2);
+AVX_SIGMOID(jit::avx512f, detail::ExpAVX2);
+AVX_TANH(jit::avx2, detail::ExpAVX2);
+AVX_TANH(jit::avx512f, detail::ExpAVX2);
+#endif
+
+#undef FOR_EACH_AVX_ISA
+#undef AVX_IDENTITY
+#undef AVX_RELU
+#undef AVX_TANH
+#undef AVX_SIGMOID

-template <>
-__m256 AVXActImpl<kIdentity>::Compute(__m256 x) const {
-  return x;
-}
 #endif

 template <typename T>

@@ -119,23 +153,6 @@ class LSTMKernelImpl : public LSTMKernel<T> {
     act_cell_d_ = GetActKernel<T>(act_cell, d);
     vmul_d_ = KernelPool::Instance().template Get<VMulKernel<T>>(d);
     vadd_d_ = KernelPool::Instance().template Get<VAddKernel<T>>(d);
-#ifdef __AVX__
-    auto GetAVXAct = [&](const std::string& type) -> std::unique_ptr<AVXAct> {
-      if (type == "sigmoid") {
-        return std::unique_ptr<AVXAct>(new AVXActImpl<kSigmoid>());
-      } else if (type == "relu") {
-        return std::unique_ptr<AVXAct>(new AVXActImpl<kRelu>());
-      } else if (type == "tanh") {
-        return std::unique_ptr<AVXAct>(new AVXActImpl<kTanh>());
-      } else if (type == "identity" || type == "") {
-        return std::unique_ptr<AVXAct>(new AVXActImpl<kIdentity>());
-      }
-      PADDLE_THROW("Not support type: %s", type);
-    };
-    avx_act_gate_ = GetAVXAct(act_gate);
-    avx_act_cand_ = GetAVXAct(act_cand);
-    avx_act_cell_ = GetAVXAct(act_cell);
-#endif
   }

   void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, const T* wp_data,

@@ -175,26 +192,61 @@ class LSTMKernelImpl : public LSTMKernel<T> {
 #endif
 };

-#define INTRI8_FLOAT(isa) \
-  template <> \
-  void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt( \
-      float* gates, const float* ct_1, float* ct, float* ht, \
-      const float* wp_data, float* checked) const { \
-    /* gates: W_ch, W_ih, W_fh, W_oh */ \
-    __m256 c, i, f, o; \
-    c = _mm256_loadu_ps(gates); \
-    i = _mm256_loadu_ps(gates + 8); \
-    f = _mm256_loadu_ps(gates + 16); \
-    o = _mm256_loadu_ps(gates + 24); \
-    /* C_t = C_t-1 * fgated + cand_gated * igated*/ \
-    c = _mm256_mul_ps(avx_act_cand_->Compute(c), avx_act_gate_->Compute(i)); \
-    i = _mm256_loadu_ps(ct_1); \
-    f = _mm256_mul_ps(i, avx_act_gate_->Compute(f)); \
-    f = _mm256_add_ps(c, f); \
-    _mm256_storeu_ps(ct, f); \
-    /* H_t = act_cell(C_t) * ogated */ \
-    o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \
-    _mm256_storeu_ps(ht, o); \
+#define INTRI8_FLOAT(isa) \
+  template <> \
+  LSTMKernelImpl<float, isa, kEQ8>::LSTMKernelImpl( \
+      const std::string& act_gate, const std::string& act_cand, \
+      const std::string& act_cell, int d) \
+      : LSTMKernel<float>() { \
+    auto GetAVXAct = [&](const std::string& type) -> std::unique_ptr<AVXAct> { \
+      if (type == "sigmoid") { \
+        return std::unique_ptr<AVXAct>(new AVXActImpl<kSigmoid, isa>()); \
+      } else if (type == "relu") { \
+        return std::unique_ptr<AVXAct>(new AVXActImpl<kRelu, isa>()); \
+      } else if (type == "tanh") { \
+        return std::unique_ptr<AVXAct>(new AVXActImpl<kTanh, isa>()); \
+      } else if (type == "identity" || type == "") { \
+        return std::unique_ptr<AVXAct>(new AVXActImpl<kIdentity, isa>()); \
+      } \
+      PADDLE_THROW("Not support type: %s", type); \
+    }; \
+    avx_act_gate_ = GetAVXAct(act_gate); \
+    avx_act_cand_ = GetAVXAct(act_cand); \
+    avx_act_cell_ = GetAVXAct(act_cell); \
+  } \
+  template <> \
+  void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt( \
+      float* gates, const float* ct_1, float* ct, float* ht, \
+      const float* wp_data, float* checked) const { \
+    /* gates: W_ch, W_ih, W_fh, W_oh */ \
+    __m256 c, i, f, o; \
+    c = _mm256_loadu_ps(gates); \
+    i = _mm256_loadu_ps(gates + 8); \
+    f = _mm256_loadu_ps(gates + 16); \
+    o = _mm256_loadu_ps(gates + 24); \
+    /* C_t = C_t-1 * fgated + cand_gated * igated*/ \
+    c = _mm256_mul_ps(avx_act_cand_->Compute(c), avx_act_gate_->Compute(i)); \
+    i = _mm256_loadu_ps(ct_1); \
+    f = _mm256_mul_ps(i, avx_act_gate_->Compute(f)); \
+    f = _mm256_add_ps(c, f); \
+    _mm256_storeu_ps(ct, f); \
+    /* H_t = act_cell(C_t) * ogated */ \
+    o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \
+    _mm256_storeu_ps(ht, o); \
+  } \
+  template <> \
+  void LSTMKernelImpl<float, isa, kEQ8>::ComputeC1H1( \
+      float* gates, float* ct, float* ht, const float* wp_data) const { \
+    __m256 c, i, o; \
+    c = _mm256_loadu_ps(gates); \
+    i = _mm256_loadu_ps(gates + 8); \
+    o = _mm256_loadu_ps(gates + 24); \
+    /* C_t = igated * cgated*/ \
+    c = _mm256_mul_ps(avx_act_gate_->Compute(i), avx_act_cand_->Compute(c)); \
+    _mm256_storeu_ps(ct, c); \
+    /* H_t = act_cell(C_t) * ogated */ \
+    o = _mm256_mul_ps(avx_act_cell_->Compute(c), avx_act_gate_->Compute(o)); \
+    _mm256_storeu_ps(ht, o); \
 }

 // TODO(TJ): optimize keq16
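
To make the shape of the jit_kernel_lstm.cc change easier to follow, here is a minimal standalone sketch of the pattern it introduces: AVXActImpl gains a jit::cpu_isa_t template parameter, and each (activation, ISA) specialization is bound to the exp kernel that matches that ISA (ExpAVX for avx, ExpAVX2 for avx2/avx512f), so a kernel built for an AVX-only CPU never reaches AVX2 instructions. The class, enum, and function names mirror the diff; the includes, the simplified cpu_isa_t enum, the plain throw in the primary template, and the omission of the input clamping done by the real kernel are assumptions made for brevity.

#include <immintrin.h>
#include <stdexcept>

namespace jit {
enum cpu_isa_t { avx, avx2, avx512f };  // simplified stand-in for platform::jit
}  // namespace jit

enum act_type { kSigmoid, kRelu, kTanh, kIdentity };

// ISA-specific exp kernels, defined elsewhere (see jit_kernel_exp.cc).
__m256 ExpAVX(__m256 x);
__m256 ExpAVX2(__m256 x);

class AVXAct {
 public:
  virtual ~AVXAct() = default;
  virtual __m256 Compute(__m256 x) const = 0;
};

// The target ISA is now part of the type, so each specialization can pick its
// exp implementation at compile time instead of sharing a single detail::Exp.
template <act_type type, jit::cpu_isa_t isa>
class AVXActImpl : public AVXAct {
 public:
  __m256 Compute(__m256 x) const override {
    throw std::runtime_error("unknown activation type");
  }
};

// sigmoid(x) = 1 / (1 + exp(-x)) using the AVX-only exp kernel.
template <>
__m256 AVXActImpl<kSigmoid, jit::avx>::Compute(__m256 x) const {
  __m256 ones = _mm256_set1_ps(1.0f);
  x = _mm256_sub_ps(_mm256_setzero_ps(), x);  // negate x
  x = ExpAVX(x);
  return _mm256_div_ps(ones, _mm256_add_ps(ones, x));
}

// Same activation, but the avx2 specialization binds the AVX2 exp kernel.
template <>
__m256 AVXActImpl<kSigmoid, jit::avx2>::Compute(__m256 x) const {
  __m256 ones = _mm256_set1_ps(1.0f);
  x = _mm256_sub_ps(_mm256_setzero_ps(), x);  // negate x
  x = ExpAVX2(x);
  return _mm256_div_ps(ones, _mm256_add_ps(ones, x));
}

In the diff itself the same idea is stamped out with the AVX_SIGMOID, AVX_TANH, AVX_RELU, and AVX_IDENTITY macros, and the kEQ8 LSTM constructor specialization then instantiates AVXActImpl<..., isa> for its own ISA, which is what removes the AVX2 code path (and the resulting illegal instruction) on machines that only support AVX.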
