Skip to content

Commit 1f0291a

Browse files
committed
add comments and follow comments
test=develop
1 parent 557229b commit 1f0291a

File tree

2 files changed

+17
-12
lines changed

2 files changed

+17
-12
lines changed

paddle/fluid/operators/math/jit_kernel_refer.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ void (*getActFunc(const std::string& type))(const T*, T*, int) { // NOLINT
116116
return nullptr;
117117
}
118118

119+
// compute ct and ht
119120
template <typename T>
120121
void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr) {
121122
T* gates = reinterpret_cast<T*>(step->gates);
@@ -199,6 +200,7 @@ void GRUH1(gru_t* step, const gru_attr_t* attr) {
199200
VMul(gates, gates + d2, ht, d);
200201
}
201202

203+
// compute the first part of GRU: ht = act_gate(r) * ht_1
202204
template <typename T>
203205
void GRUHtPart1(gru_t* step, const gru_attr_t* attr) {
204206
// W: {W_update, W_reset; W_state}
@@ -210,6 +212,8 @@ void GRUHtPart1(gru_t* step, const gru_attr_t* attr) {
210212
VMul(ht_1, gates + attr->d, ht, attr->d);
211213
}
212214

215+
// compute the second part of GRU:
216+
// ht = act_gate(u) * act_cand(s) + (1-act_gate(u)) * ht_1
213217
template <typename T>
214218
void GRUHtPart2(gru_t* step, const gru_attr_t* attr) {
215219
T* gates = reinterpret_cast<T*>(step->gates);

paddle/fluid/operators/math/jit_kernel_test.cc

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ TEST(JitKernel, vrelu) {
8686
vrelu_intri8(d, x_data, zref_data);
8787
}
8888
auto si1 = GetCurrentUS();
89-
VLOG(30) << "Vec size 8 intr takes: " << (si1 - si0) / repeat;
89+
VLOG(30) << "Vec size 8 intr takes: " << (si1 - si0) / repeat << " us";
9090
}
9191
#endif
9292
auto ttgts = GetCurrentUS();
@@ -96,7 +96,7 @@ TEST(JitKernel, vrelu) {
9696
auto ttgte = GetCurrentUS();
9797
VLOG(30) << "Vec size " << d
9898
<< ": refer takes: " << (trefe - trefs) / repeat
99-
<< " us, tgt takes: " << (ttgte - ttgts) / repeat;
99+
<< " us, tgt takes: " << (ttgte - ttgts) / repeat << " us";
100100
for (int i = 0; i < d; ++i) {
101101
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
102102
}
@@ -129,7 +129,7 @@ TEST(JitKernel, vaddbias) {
129129

130130
VLOG(30) << "Vec size " << d
131131
<< ": refer takes: " << (trefe - trefs) / repeat
132-
<< " us, tgt takes: " << (ttgte - ttgts) / repeat;
132+
<< " us, tgt takes: " << (ttgte - ttgts) / repeat << " us";
133133
for (int i = 0; i < d; ++i) {
134134
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
135135
}
@@ -182,7 +182,7 @@ TEST(JitKernel, vexp) {
182182
#else
183183
<< " us, "
184184
#endif
185-
<< "tgt takes: " << (ttgte - ttgts) / repeat;
185+
<< "tgt takes: " << (ttgte - ttgts) / repeat << " us";
186186
for (int i = 0; i < d; ++i) {
187187
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
188188
}
@@ -238,7 +238,7 @@ TEST(JitKernel, vsigmoid) {
238238
VLOG(30) << "Vec size " << d
239239
<< ": refer takes: " << (trefe - trefs) / repeat
240240
<< " us, better(jit exp) takes: " << (tmkle - tmkls) / repeat
241-
<< " us, tgt takes: " << (ttgte - ttgts) / repeat;
241+
<< " us, tgt takes: " << (ttgte - ttgts) / repeat << " us";
242242
for (int i = 0; i < d; ++i) {
243243
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
244244
}
@@ -299,7 +299,7 @@ TEST(JitKernel, vtanh) {
299299
VLOG(30) << "Vec size " << d
300300
<< ": refer takes: " << (trefe - trefs) / repeat
301301
<< " us, better(jit exp) takes: " << (tmkle - tmkls) / repeat
302-
<< " us, tgt takes: " << (ttgte - ttgts) / repeat;
302+
<< " us, tgt takes: " << (ttgte - ttgts) / repeat << " us";
303303
for (int i = 0; i < d; ++i) {
304304
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
305305
}
@@ -400,7 +400,7 @@ TEST(JitKernel, lstm) {
400400
VLOG(30) << "Vec size " << d
401401
<< ": refer takes: " << (trefe - trefs) / repeat
402402
<< " us, better(jit) takes: " << (tmkle - tmkls) / repeat
403-
<< " us, tgt takes: " << (ttgte - ttgts) / repeat;
403+
<< " us, tgt takes: " << (ttgte - ttgts) / repeat << " us";
404404
}
405405
}
406406

@@ -474,7 +474,7 @@ TEST(JitKernel, vscal) {
474474
}
475475
auto si3 = GetCurrentUS();
476476
VLOG(30) << "Vec size 8 intr takes: " << (si1 - si0) / repeat
477-
<< " us, inplace: " << (si3 - si2) / repeat;
477+
<< " us, inplace: " << (si3 - si2) / repeat << " us";
478478
}
479479
#endif
480480

@@ -498,7 +498,8 @@ TEST(JitKernel, vscal) {
498498
<< " us, "
499499
#endif
500500
<< "tgt takes: " << (ttgte - ttgts) / repeat
501-
<< "us, tgt inplace takes: " << (ttgte1 - ttgts1) / repeat;
501+
<< "us, tgt inplace takes: " << (ttgte1 - ttgts1) / repeat
502+
<< " us";
502503
for (int i = 0; i < d; ++i) {
503504
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
504505
}
@@ -573,7 +574,7 @@ TEST(JitKernel, vmul) {
573574
#else
574575
<< " us, "
575576
#endif
576-
<< "tgt takes: " << (ttgte - ttgts) / repeat;
577+
<< "tgt takes: " << (ttgte - ttgts) / repeat << " us";
577578
for (int i = 0; i < d; ++i) {
578579
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
579580
}
@@ -648,7 +649,7 @@ TEST(JitKernel, vadd) {
648649
#else
649650
<< " us, "
650651
#endif
651-
<< "tgt takes: " << (ttgte - ttgts) / repeat;
652+
<< "tgt takes: " << (ttgte - ttgts) / repeat << " us";
652653
for (int i = 0; i < d; ++i) {
653654
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
654655
}
@@ -701,7 +702,7 @@ TEST(JitKernel, vaddrelu) {
701702
VLOG(30) << "Vec size " << d
702703
<< ": refer takes: " << (trefe - trefs) / repeat
703704
<< " us, better takes: " << (tmkle - tmkls) / repeat << " us, "
704-
<< "tgt takes: " << (ttgte - ttgts) / repeat;
705+
<< "tgt takes: " << (ttgte - ttgts) / repeat << " us";
705706
for (int i = 0; i < d; ++i) {
706707
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
707708
}

0 commit comments

Comments
 (0)