@@ -86,9 +86,9 @@ __m256 AVXActImpl<kIdentity>::Compute(__m256 x) const {
86
86
template <typename T, jit::cpu_isa_t isa, jit_block>
87
87
class LSTMKernelImpl : public LSTMKernel <T> {
88
88
public:
89
- explicit LSTMKernelImpl (int d, const std::string& act_gate,
89
+ explicit LSTMKernelImpl (const std::string& act_gate,
90
90
const std::string& act_cand,
91
- const std::string& act_cell)
91
+ const std::string& act_cell, int d )
92
92
: LSTMKernel<T>() {
93
93
d_ = d;
94
94
d2_ = d * 2 ;
@@ -134,7 +134,8 @@ class LSTMKernelImpl : public LSTMKernel<T> {
134
134
#endif
135
135
}
136
136
137
- void ComputeCtHt (T* gates, const T* ct_1, T* ct, T* ht) const override {
137
+ void ComputeCtHt (T* gates, const T* ct_1, T* ct, T* ht,
138
+ T* checked) const override {
138
139
// gates: W_ch, W_ih, W_fh, W_oh
139
140
act_gate_3d_->Compute (gates + d_, gates + d_);
140
141
@@ -162,7 +163,8 @@ class LSTMKernelImpl : public LSTMKernel<T> {
162
163
#define INTRI8_FLOAT (isa ) \
163
164
template <> \
164
165
void LSTMKernelImpl<float , isa, kEQ8 >::ComputeCtHt( \
165
- float * gates, const float * ct_1, float * ct, float * ht) const { \
166
+ float * gates, const float * ct_1, float * ct, float * ht, float * checked) \
167
+ const { \
166
168
/* gates: W_ch, W_ih, W_fh, W_oh */ \
167
169
__m256 c, i, f, o; \
168
170
c = _mm256_loadu_ps (gates); \
@@ -192,21 +194,86 @@ INTRI8_FLOAT(jit::avx2);
192
194
INTRI8_FLOAT (jit::avx512f);
193
195
#endif
194
196
195
- #define JITKERNEL_DECLARE_LSTM (ker_class, ker_dtype ) \
196
- template <> \
197
- std::shared_ptr<const ker_class<ker_dtype>> \
198
- KernelPool::Get<ker_class<ker_dtype>, int , const std::string&, \
199
- const std::string&, const std::string&>( \
200
- int d, const std::string& act_gate, const std::string& act_cand, \
201
- const std::string& act_cell)
197
+ /* Peephole JitKernel */
198
+ template <typename T, jit::cpu_isa_t isa, jit_block>
199
+ class PeepholeKernelImpl : public LSTMKernel <T> {
200
+ public:
201
+ explicit PeepholeKernelImpl (const std::string& act_gate,
202
+ const std::string& act_cand,
203
+ const std::string& act_cell, int d)
204
+ : LSTMKernel<T>() {
205
+ d_ = d;
206
+ d2_ = d * 2 ;
207
+ d3_ = d * 3 ;
208
+ auto GetActKernel = [&](const std::string& type,
209
+ int n) -> std::shared_ptr<const VActKernel<T>> {
210
+ if (type == " sigmoid" ) {
211
+ return std::dynamic_pointer_cast<const VActKernel<T>>(
212
+ KernelPool::Instance ().template Get <VSigmoidKernel<T>>(n));
213
+ } else if (type == " relu" ) {
214
+ return std::dynamic_pointer_cast<const VActKernel<T>>(
215
+ KernelPool::Instance ().template Get <VReluKernel<T>>(n));
216
+ } else if (type == " tanh" ) {
217
+ return std::dynamic_pointer_cast<const VActKernel<T>>(
218
+ KernelPool::Instance ().template Get <VTanhKernel<T>>(n));
219
+ } else if (type == " identity" || type == " " ) {
220
+ return std::dynamic_pointer_cast<const VActKernel<T>>(
221
+ KernelPool::Instance ().template Get <VIdentityKernel<T>>(n));
222
+ }
223
+ PADDLE_THROW (" Not support type: %s" , type);
224
+ };
225
+ act_gate_3d_ = GetActKernel (act_gate, d * 3 );
226
+ act_cand_d_ = GetActKernel (act_cand, d);
227
+ act_cell_d_ = GetActKernel (act_cell, d);
228
+ vmul_d_ = KernelPool::Instance ().template Get <VMulKernel<T>>(d);
229
+ vadd_d_ = KernelPool::Instance ().template Get <VAddKernel<T>>(d);
230
+ }
231
+
232
+ void ComputeCtHt (T* gates, const T* ct_1, T* ct, T* ht,
233
+ T* checked) const override {
234
+ // gates: W_ch, W_ih, W_fh, W_oh
235
+ act_gate_3d_->Compute (gates + d_, gates + d_);
236
+
237
+ /* C_t = C_t-1 * fgated + cand_gated * igated */
238
+ act_cand_d_->Compute (gates, gates);
239
+ vmul_d_->Compute (gates, gates + d_, gates + d_);
240
+ vmul_d_->Compute (ct_1, gates + d2_, gates + d2_);
241
+ vadd_d_->Compute (gates + d_, gates + d2_, ct);
242
+
243
+ /* H_t = act_cell(C_t) * ogated */
244
+ act_cell_d_->Compute (ct, gates + d2_);
245
+ vmul_d_->Compute (gates + d2_, gates + d3_, ht);
246
+ }
202
247
203
- #define JITKERNEL_KEY_LSTM (ker_key, dtype_key ) \
204
- #ker_key #dtype_key + std::to_string(d) + act_gate + act_cand + act_cell
248
+ private:
249
+ int d_, d2_, d3_;
250
+ std::shared_ptr<const VActKernel<T>> act_gate_3d_, act_cand_d_, act_cell_d_;
251
+ std::shared_ptr<const VMulKernel<T>> vmul_d_;
252
+ std::shared_ptr<const VAddKernel<T>> vadd_d_;
253
+ };
254
+
255
// Declarator of the KernelPool::Get explicit specialization for the LSTM
// kernel: three activation names, the segment size d, and the peephole switch.
// The parameter names introduced here (act_gate, act_cand, act_cell, d,
// use_peephole) are the ones JITKERNEL_KEY_LSTM and JITKERNEL_NEW_LSTM_IMPL
// refer to.
// The name and the parameter list must not be separated by whitespace,
// otherwise the preprocessor defines an object-like macro.
// NOTE(review): `ker_class` is accepted but not referenced; the expansion
// names LSTMKernel directly.
#define JITKERNEL_DECLARE_LSTM(ker_class, ker_dtype) \
  template <> \
  std::shared_ptr<const LSTMKernel<ker_dtype>> \
  KernelPool::Get<LSTMKernel<ker_dtype>, const std::string&, \
                  const std::string&, const std::string&, int, bool>( \
      const std::string& act_gate, const std::string& act_cand, \
      const std::string& act_cell, int d, bool use_peephole)
205
262
206
- #define JITKERNEL_NEW_LSTM_IMPL (ker, dtype, isa, k ) \
207
- p = std::dynamic_pointer_cast<ker<dtype>>( \
208
- std::make_shared<ker##Impl<dtype, isa, k>>(d, act_gate, act_cand, \
209
- act_cell))
263
// Builds the lookup key for an LSTM kernel: the stringified kernel and dtype
// tags, followed by d, the three activation names and a peephole marker.
// Expects `d`, `act_gate`, `act_cand`, `act_cell` and `use_peephole` to be in
// scope at the expansion site.
// The name and the parameter list must not be separated by whitespace,
// otherwise the preprocessor defines an object-like macro.
#define JITKERNEL_KEY_LSTM(ker_key, dtype_key) \
  #ker_key #dtype_key + std::to_string(d) + act_gate + act_cand + act_cell + \
      (use_peephole ? " p" : " n")
266
+
267
// Creates the kernel object for one (dtype, isa, block) combination and stores
// it in `p`: PeepholeKernelImpl when `use_peephole` is set, ker##Impl
// otherwise. `p`, `use_peephole`, the activation names and `d` come from the
// expansion site.
// The name and the parameter list must not be separated by whitespace,
// otherwise the preprocessor defines an object-like macro.
// NOTE(review): this expands to a bare if/else statement, so it is only safe
// where the expansion is a complete statement that is not itself the
// unbraced body of another `if` -- confirm against REGISTER_JITKERNEL_ARGS.
#define JITKERNEL_NEW_LSTM_IMPL(ker, dtype, isa, k) \
  if (use_peephole) { \
    p = std::dynamic_pointer_cast<ker<dtype>>( \
        std::make_shared<PeepholeKernelImpl<dtype, isa, k>>( \
            act_gate, act_cand, act_cell, d)); \
  } else { \
    p = std::dynamic_pointer_cast<ker<dtype>>( \
        std::make_shared<ker##Impl<dtype, isa, k>>(act_gate, act_cand, \
                                                   act_cell, d)); \
  }
210
277
211
278
// Hands the kernel tag, the kernel class and the three JITKERNEL_*_LSTM helper
// macros to REGISTER_JITKERNEL_ARGS. That macro is defined outside this chunk;
// presumably it emits the KernelPool::Get specialization -- confirm there.
REGISTER_JITKERNEL_ARGS (lstm, LSTMKernel, JITKERNEL_DECLARE_LSTM,
212
279
JITKERNEL_KEY_LSTM, JITKERNEL_NEW_LSTM_IMPL);
@@ -215,7 +282,6 @@ REGISTER_JITKERNEL_ARGS(lstm, LSTMKernel, JITKERNEL_DECLARE_LSTM,
215
282
// Remove the LSTM helper macros so they are not visible past this point.
#undef JITKERNEL_DECLARE_LSTM
216
283
#undef JITKERNEL_KEY_LSTM
217
284
#undef JITKERNEL_NEW_LSTM_IMPL
218
-
219
285
} // namespace jitkernel
220
286
} // namespace math
221
287
} // namespace operators
0 commit comments