Skip to content

Commit b55c247

Browse files
committed
add lstm compute unit test
1 parent 2a00969 commit b55c247

File tree

1 file changed

+117
-0
lines changed

1 file changed

+117
-0
lines changed

paddle/fluid/operators/math/jit_kernel_test.cc

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,123 @@ TEST(JitKernel, vtanh) {
328328
}
329329
}
330330

331+
void lstm_ctht_ref(
332+
const std::shared_ptr<
333+
const paddle::operators::math::jitkernel::VSigmoidKernel<float>>&
334+
vsigmoid_3d,
335+
const std::shared_ptr<
336+
const paddle::operators::math::jitkernel::VTanhKernel<float>>& vtanh_d,
337+
const std::shared_ptr<
338+
const paddle::operators::math::jitkernel::VExpKernel<float>>& vexp_1,
339+
const int d, float* gates, const float* ct_1, float* ct, float* ht) {
340+
vsigmoid_3d->Compute(gates + d, gates + d);
341+
vtanh_d->Compute(gates, gates);
342+
const float *i = gates + d, *f = gates + d * 2, *o = gates + d * 3;
343+
const float min = SIGMOID_THRESHOLD_MIN;
344+
const float max = SIGMOID_THRESHOLD_MAX;
345+
for (int k = 0; k < d; ++k) {
346+
// C_t = C_t-1 * fgated + cand_gated * igated
347+
ct[k] = ct_1[k] * f[k] + gates[k] * i[k];
348+
// H_t = act_cell(C_t) * ogated
349+
float tmp = ct[k] * 2;
350+
tmp = 0.f - ((tmp < min) ? min : ((tmp > max) ? max : tmp));
351+
vexp_1->Compute(&tmp, &tmp);
352+
tmp = 2.f / (1.f + tmp) - 1.f;
353+
ht[k] = tmp * o[k];
354+
}
355+
}
356+
357+
void lstm_ctht_better(
358+
const std::shared_ptr<
359+
const paddle::operators::math::jitkernel::VSigmoidKernel<float>>&
360+
vsigmoid_3d,
361+
const std::shared_ptr<
362+
const paddle::operators::math::jitkernel::VTanhKernel<float>>& vtanh_d,
363+
const std::shared_ptr<
364+
const paddle::operators::math::jitkernel::VMulKernel<float>>& vmul_d,
365+
const std::shared_ptr<
366+
const paddle::operators::math::jitkernel::VAddKernel<float>>& vadd_d,
367+
const int d, float* gates, const float* ct_1, float* ct, float* ht) {
368+
int d2 = d * 2;
369+
vsigmoid_3d->Compute(gates + d, gates + d);
370+
vtanh_d->Compute(gates, gates);
371+
vmul_d->Compute(gates, gates + d, gates + d);
372+
vmul_d->Compute(ct_1, gates + d2, gates + d2);
373+
vadd_d->Compute(gates + d, gates + d2, ct);
374+
/* H_t = act_cell(C_t) * ogated */
375+
vtanh_d->Compute(ct, gates + d2);
376+
vmul_d->Compute(gates + d2, gates + d * 3, ht);
377+
}
378+
379+
// Compares LSTMKernel<float>::ComputeCtHt with lstm_ctht_ref for a range of
// vector sizes, then logs the average per-call time of the scalar reference,
// the kernel-composed variant and the target kernel.
TEST(JitKernel, lstm) {
  namespace jit = paddle::operators::math::jitkernel;
  for (int d : {7, 8, 15, 16, 30, 32, 64, 100}) {
    int gate_len = d * 4;
    int sigmoid_len = d * 3;

    // Same random gate values for the target (x) and the reference (xref).
    std::vector<float> x(gate_len), xref(gate_len);
    std::vector<float> ct_1(d);
    std::vector<float> ct_tgt(d), ht_tgt(d);
    std::vector<float> ct_ref(d), ht_ref(d);
    RandomVec<float>(gate_len, x.data(), -2.f, 2.f);
    RandomVec<float>(d, ct_1.data(), -2.f, 2.f);
    memcpy(xref.data(), x.data(), sizeof(float) * gate_len);

    const std::string act_gate = "sigmoid";
    const std::string act_cand = "tanh";
    const std::string act_cell = "tanh";

    auto&& pool = jit::KernelPool::Instance();
    const auto& ker =
        pool.template Get<jit::LSTMKernel<float>, int, const std::string&,
                          const std::string&, const std::string&>(
            d, act_gate, act_cand, act_cell);
    // below kernels are used to compute refer
    const auto& vsigmoid_3d =
        pool.template Get<jit::VSigmoidKernel<float>>(sigmoid_len);
    const auto& vtanh_d = pool.template Get<jit::VTanhKernel<float>>(d);
    const auto& vexp_1 = pool.template Get<jit::VExpKernel<float>>(1);
    const auto& vmul_d = pool.template Get<jit::VMulKernel<float>>(d);
    const auto& vadd_d = pool.template Get<jit::VAddKernel<float>>(d);

    float* tgt_gates = x.data();
    float* ref_gates = xref.data();
    const float* prev_cell = ct_1.data();
    float* tgt_cell = ct_tgt.data();
    float* tgt_hidden = ht_tgt.data();
    float* ref_cell = ct_ref.data();
    float* ref_hidden = ht_ref.data();

    // compute once to check correctness
    lstm_ctht_ref(vsigmoid_3d, vtanh_d, vexp_1, d, ref_gates, prev_cell,
                  ref_cell, ref_hidden);
    ker->ComputeCtHt(tgt_gates, prev_cell, tgt_cell, tgt_hidden);
    for (int k = 0; k < d; ++k) {
      EXPECT_NEAR(tgt_cell[k], ref_cell[k], 1e-3);
      EXPECT_NEAR(tgt_hidden[k], ref_hidden[k], 1e-3);
    }

    // Timing only from here on: the gate buffers are modified in place by
    // every call, so the outputs below are not checked.
    auto better_begin = GetCurrentUS();
    for (int r = 0; r < repeat; ++r) {
      lstm_ctht_better(vsigmoid_3d, vtanh_d, vmul_d, vadd_d, d, ref_gates,
                       prev_cell, ref_cell, ref_hidden);
    }
    auto better_end = GetCurrentUS();

    auto refer_begin = GetCurrentUS();
    for (int r = 0; r < repeat; ++r) {
      lstm_ctht_ref(vsigmoid_3d, vtanh_d, vexp_1, d, ref_gates, prev_cell,
                    ref_cell, ref_hidden);
    }
    auto refer_end = GetCurrentUS();

    auto tgt_begin = GetCurrentUS();
    for (int r = 0; r < repeat; ++r) {
      ker->ComputeCtHt(tgt_gates, prev_cell, tgt_cell, tgt_hidden);
    }
    auto tgt_end = GetCurrentUS();

    VLOG(3) << "Vec size " << d
            << ": refer takes: " << (refer_end - refer_begin) / repeat
            << " us, better(jit) takes: " << (better_end - better_begin) / repeat
            << " us, tgt takes: " << (tgt_end - tgt_begin) / repeat;
  }
}
447+
331448
void vscal_ref(const int n, const float a, const float* x, float* y) {
332449
for (int i = 0; i < n; ++i) {
333450
y[i] = a * x[i];

0 commit comments

Comments
 (0)