@@ -328,6 +328,123 @@ TEST(JitKernel, vtanh) {
328
328
}
329
329
}
330
330
331
+ void lstm_ctht_ref (
332
+ const std::shared_ptr<
333
+ const paddle::operators::math::jitkernel::VSigmoidKernel<float >>&
334
+ vsigmoid_3d,
335
+ const std::shared_ptr<
336
+ const paddle::operators::math::jitkernel::VTanhKernel<float >>& vtanh_d,
337
+ const std::shared_ptr<
338
+ const paddle::operators::math::jitkernel::VExpKernel<float >>& vexp_1,
339
+ const int d, float * gates, const float * ct_1, float * ct, float * ht) {
340
+ vsigmoid_3d->Compute (gates + d, gates + d);
341
+ vtanh_d->Compute (gates, gates);
342
+ const float *i = gates + d, *f = gates + d * 2 , *o = gates + d * 3 ;
343
+ const float min = SIGMOID_THRESHOLD_MIN;
344
+ const float max = SIGMOID_THRESHOLD_MAX;
345
+ for (int k = 0 ; k < d; ++k) {
346
+ // C_t = C_t-1 * fgated + cand_gated * igated
347
+ ct[k] = ct_1[k] * f[k] + gates[k] * i[k];
348
+ // H_t = act_cell(C_t) * ogated
349
+ float tmp = ct[k] * 2 ;
350
+ tmp = 0 .f - ((tmp < min) ? min : ((tmp > max) ? max : tmp));
351
+ vexp_1->Compute (&tmp, &tmp);
352
+ tmp = 2 .f / (1 .f + tmp) - 1 .f ;
353
+ ht[k] = tmp * o[k];
354
+ }
355
+ }
356
+
357
+ void lstm_ctht_better (
358
+ const std::shared_ptr<
359
+ const paddle::operators::math::jitkernel::VSigmoidKernel<float >>&
360
+ vsigmoid_3d,
361
+ const std::shared_ptr<
362
+ const paddle::operators::math::jitkernel::VTanhKernel<float >>& vtanh_d,
363
+ const std::shared_ptr<
364
+ const paddle::operators::math::jitkernel::VMulKernel<float >>& vmul_d,
365
+ const std::shared_ptr<
366
+ const paddle::operators::math::jitkernel::VAddKernel<float >>& vadd_d,
367
+ const int d, float * gates, const float * ct_1, float * ct, float * ht) {
368
+ int d2 = d * 2 ;
369
+ vsigmoid_3d->Compute (gates + d, gates + d);
370
+ vtanh_d->Compute (gates, gates);
371
+ vmul_d->Compute (gates, gates + d, gates + d);
372
+ vmul_d->Compute (ct_1, gates + d2, gates + d2);
373
+ vadd_d->Compute (gates + d, gates + d2, ct);
374
+ /* H_t = act_cell(C_t) * ogated */
375
+ vtanh_d->Compute (ct, gates + d2);
376
+ vmul_d->Compute (gates + d2, gates + d * 3 , ht);
377
+ }
378
+
379
+ TEST (JitKernel, lstm) {
380
+ namespace jit = paddle::operators::math::jitkernel;
381
+ for (int d : {7 , 8 , 15 , 16 , 30 , 32 , 64 , 100 }) {
382
+ int d4 = d * 4 ;
383
+ int d3 = d * 3 ;
384
+ std::vector<float > x (d4), xref (d4);
385
+ std::vector<float > ct_1 (d), ct_tgt (d), ht_tgt (d);
386
+ std::vector<float > ct_ref (d), ht_ref (d);
387
+ RandomVec<float >(d4, x.data (), -2 .f , 2 .f );
388
+ RandomVec<float >(d, ct_1.data (), -2 .f , 2 .f );
389
+ memcpy (xref.data (), x.data (), sizeof (float ) * d4);
390
+ std::string act_gate = " sigmoid" , act_cand = " tanh" , act_cell = " tanh" ;
391
+ const auto & ker =
392
+ jit::KernelPool::Instance ()
393
+ .template Get <jit::LSTMKernel<float >, int , const std::string&,
394
+ const std::string&, const std::string&>(
395
+ d, act_gate, act_cand, act_cell);
396
+ // below kernels are used to compute refer
397
+ const auto & vsigmoid_3d =
398
+ jit::KernelPool::Instance ().template Get <jit::VSigmoidKernel<float >>(
399
+ d3);
400
+ const auto & vtanh_d =
401
+ jit::KernelPool::Instance ().template Get <jit::VTanhKernel<float >>(d);
402
+ const auto & vexp_1 =
403
+ jit::KernelPool::Instance ().template Get <jit::VExpKernel<float >>(1 );
404
+ const auto & vmul_d =
405
+ jit::KernelPool::Instance ().template Get <jit::VMulKernel<float >>(d);
406
+ const auto & vadd_d =
407
+ jit::KernelPool::Instance ().template Get <jit::VAddKernel<float >>(d);
408
+
409
+ float * x_data = x.data ();
410
+ float * xref_data = xref.data ();
411
+ const float * ct_1_data = ct_1.data ();
412
+ float * ct_tgt_data = ct_tgt.data ();
413
+ float * ht_tgt_data = ht_tgt.data ();
414
+ float * ct_ref_data = ct_ref.data ();
415
+ float * ht_ref_data = ht_ref.data ();
416
+ // compute once to check correctness
417
+ lstm_ctht_ref (vsigmoid_3d, vtanh_d, vexp_1, d, xref_data, ct_1_data,
418
+ ct_ref_data, ht_ref_data);
419
+ ker->ComputeCtHt (x_data, ct_1_data, ct_tgt_data, ht_tgt_data);
420
+ for (int i = 0 ; i < d; ++i) {
421
+ EXPECT_NEAR (ct_tgt_data[i], ct_ref_data[i], 1e-3 );
422
+ EXPECT_NEAR (ht_tgt_data[i], ht_ref_data[i], 1e-3 );
423
+ }
424
+
425
+ auto tmkls = GetCurrentUS ();
426
+ for (int i = 0 ; i < repeat; ++i) {
427
+ lstm_ctht_better (vsigmoid_3d, vtanh_d, vmul_d, vadd_d, d, xref_data,
428
+ ct_1_data, ct_ref_data, ht_ref_data);
429
+ }
430
+ auto tmkle = GetCurrentUS ();
431
+ auto trefs = GetCurrentUS ();
432
+ for (int i = 0 ; i < repeat; ++i) {
433
+ lstm_ctht_ref (vsigmoid_3d, vtanh_d, vexp_1, d, xref_data, ct_1_data,
434
+ ct_ref_data, ht_ref_data);
435
+ }
436
+ auto trefe = GetCurrentUS ();
437
+ auto ttgts = GetCurrentUS ();
438
+ for (int i = 0 ; i < repeat; ++i) {
439
+ ker->ComputeCtHt (x_data, ct_1_data, ct_tgt_data, ht_tgt_data);
440
+ }
441
+ auto ttgte = GetCurrentUS ();
442
+ VLOG (3 ) << " Vec size " << d << " : refer takes: " << (trefe - trefs) / repeat
443
+ << " us, better(jit) takes: " << (tmkle - tmkls) / repeat
444
+ << " us, tgt takes: " << (ttgte - ttgts) / repeat;
445
+ }
446
+ }
447
+
331
448
void vscal_ref (const int n, const float a, const float * x, float * y) {
332
449
for (int i = 0 ; i < n; ++i) {
333
450
y[i] = a * x[i];
0 commit comments