@@ -232,40 +232,28 @@ use lstm_x_t as input and compute as standard LSTM.
 template <typename T>
 inline void bias_relu(const int n, const T* x, const T* bias, T* y) {
   if (bias) {
-    for (int i = 0; i < n; ++i) {
-      y[i] = x[i] + bias[0];
-    }
-    math::vec_relu<T>(n, y, y);
+    math::vec_add_bias<T, platform::jit::avx>(n, *bias, x, y);
+    math::vec_relu<T, platform::jit::avx>(n, y, y);
   } else {
-    math::vec_relu<T>(n, x, y);
+    math::vec_relu<T, platform::jit::avx>(n, x, y);
   }
 }

-template <typename DeviceContext, typename T>
-inline void vec_softmax(const math::BlasT<DeviceContext, T>& blas, const int n,
-                        const T* x, T* y) {
+template <typename T>
+inline void vec_softmax(const int n, const T* x, T* y) {
   T scalar = x[0];
   // max
   for (int i = 1; i < n; ++i) {
     scalar = scalar < x[i] ? x[i] : scalar;
   }
-
-  // sub
-  for (int i = 0; i < n; ++i) {
-    y[i] = x[i] - scalar;
-  }
-
-  // exp
-  blas.VEXP(n, y, y);
-
+  math::vec_add_bias<T, platform::jit::avx>(n, -scalar, x, y);  // sub
+  math::vec_exp<T>(n, y, y);                                    // exp
   // sum
   scalar = T(0);
   for (int i = 0; i < n; ++i) {
     scalar += y[i];
   }
-
-  // scale
-  blas.SCAL(n, static_cast<T>(1) / scalar, y);
+  math::vec_scal<T>(n, static_cast<T>(1) / scalar, y);  // scale
 }

 template <typename T>
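
The rewritten vec_softmax above is the usual max-shifted softmax: take the max, subtract it, exponentiate, sum, then scale by the reciprocal of the sum, with the element-wise steps now delegated to the math::vec_* helpers instead of BLAS. For cross-checking the vectorized path, here is a minimal scalar sketch of the same computation; it is standalone (standard library only) and illustrative, not part of the patch.

// Scalar reference for the max-shifted softmax computed by vec_softmax above.
// Illustrative sketch only; standalone, no Paddle headers.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

template <typename T>
void softmax_ref(const int n, const T* x, T* y) {
  T max_val = x[0];
  for (int i = 1; i < n; ++i) max_val = std::max(max_val, x[i]);  // max
  T sum = T(0);
  for (int i = 0; i < n; ++i) {
    y[i] = std::exp(x[i] - max_val);  // sub + exp
    sum += y[i];
  }
  const T inv_sum = T(1) / sum;
  for (int i = 0; i < n; ++i) y[i] *= inv_sum;  // scale
}

int main() {
  std::vector<float> x = {1.f, 2.f, 3.f}, y(3);
  softmax_ref(3, x.data(), y.data());
  for (float v : y) std::printf("%f ", v);  // approx 0.090 0.245 0.665
  std::printf("\n");
  return 0;
}
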
@@ -311,11 +299,21 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE_EQ(c0->dims()[0], N, "C0 dims should be %d x %d.", N, D);
     fc_out->Resize({max_seq_len, 1});

-    math::VecActivations<T> act_functor;
     std::function<void(const int, const T*, T*)> act_gate, act_cell, act_cand;
-    act_gate = act_functor(ctx.Attr<std::string>("gate_activation"));
-    act_cell = act_functor(ctx.Attr<std::string>("cell_activation"));
-    act_cand = act_functor(ctx.Attr<std::string>("candidate_activation"));
+    auto& act_gate_str = ctx.Attr<std::string>("gate_activation");
+    auto& act_cell_str = ctx.Attr<std::string>("cell_activation");
+    auto& act_cand_str = ctx.Attr<std::string>("candidate_activation");
+    if (platform::jit::MayIUse(platform::jit::avx)) {
+      math::VecActivations<T, platform::jit::avx> act_functor;
+      act_gate = act_functor(act_gate_str);
+      act_cell = act_functor(act_cell_str);
+      act_cand = act_functor(act_cand_str);
+    } else {
+      math::VecActivations<T, platform::jit::isa_any> act_functor;
+      act_gate = act_functor(act_gate_str);
+      act_cell = act_functor(act_cell_str);
+      act_cand = act_functor(act_cand_str);
+    }

     const T* x_data = x->data<T>();
     const T* h0_data = h0 ? h0->data<T>() : NULL;
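
The hunk above changes activation selection from a single VecActivations<T> functor to a runtime dispatch: the target ISA becomes a template parameter, and platform::jit::MayIUse picks the AVX specialization only on CPUs that support it, falling back to isa_any otherwise. Below is a minimal self-contained sketch of that dispatch pattern; Isa, may_i_use_avx, and VecActivationsSketch are assumed stand-in names for platform::jit::avx / platform::jit::isa_any, platform::jit::MayIUse, and math::VecActivations, and both specializations here use plain scalar loops.

// Sketch of runtime ISA dispatch for activation kernels. Illustrative only;
// names are stand-ins, not the Paddle API.
#include <cmath>
#include <cstdio>
#include <functional>
#include <stdexcept>
#include <string>

enum class Isa { any, avx };

template <typename T, Isa isa>
struct VecActivationsSketch {
  // A real avx specialization would return vectorized kernels; this sketch
  // returns plain scalar loops regardless of the ISA parameter.
  std::function<void(int, const T*, T*)> operator()(const std::string& name) {
    if (name == "sigmoid") {
      return [](int n, const T* x, T* y) {
        for (int i = 0; i < n; ++i) y[i] = T(1) / (T(1) + std::exp(-x[i]));
      };
    }
    if (name == "tanh") {
      return [](int n, const T* x, T* y) {
        for (int i = 0; i < n; ++i) y[i] = std::tanh(x[i]);
      };
    }
    throw std::invalid_argument("unknown activation: " + name);
  }
};

// Stand-in for a CPUID-based feature check such as platform::jit::MayIUse.
bool may_i_use_avx() { return false; }

template <typename T>
std::function<void(int, const T*, T*)> pick_activation(const std::string& name) {
  if (may_i_use_avx()) {
    VecActivationsSketch<T, Isa::avx> act_functor;
    return act_functor(name);
  }
  VecActivationsSketch<T, Isa::any> act_functor;
  return act_functor(name);
}

int main() {
  auto act_gate = pick_activation<float>("sigmoid");
  float x[3] = {-1.f, 0.f, 1.f}, y[3];
  act_gate(3, x, y);
  for (float v : y) std::printf("%f ", v);  // approx 0.269 0.500 0.731
  std::printf("\n");
  return 0;
}
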
@@ -363,7 +361,7 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
                                           fc_out_data);
       }
       // 1d. softmax
-      vec_softmax<DeviceContext, T>(blas, seq_len, fc_out_data, fc_out_data);
+      vec_softmax<T>(seq_len, fc_out_data, fc_out_data);
       // mul x(seq_len*M) and sum pool
       math::FCCompute<DeviceContext, T>(blas, 1, M, seq_len, fc_out_data,
                                         cur_x_data, lstm_x_data);