Skip to content

Commit 266e625

Browse files
authored
Merge pull request #15399 from tensor-tang/refine/seqpool/fc
fix cpu jitkernel test and refine benchmark test
2 parents 885c4e5 + 316e44b commit 266e625

File tree

2 files changed

+118
-43
lines changed

2 files changed

+118
-43
lines changed

paddle/fluid/operators/jit/benchmark.cc

Lines changed: 98 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,54 @@
2222
#include "paddle/fluid/platform/device_tracer.h"
2323
#include "paddle/fluid/platform/place.h"
2424
#include "paddle/fluid/platform/port.h"
25+
#include "paddle/fluid/platform/variant.h" // for UNUSED
2526

2627
DEFINE_int32(burning, 10, "Burning times.");
2728
DEFINE_int32(repeat, 3000, "Repeat times.");
2829
DEFINE_int32(max_size, 1000, "The Max size would be tested.");
30+
DEFINE_string(filter, "", "The Benchmark name would be run.");
31+
32+
class BenchJITKernel {
33+
public:
34+
BenchJITKernel() = default;
35+
virtual ~BenchJITKernel() = default;
36+
virtual void Run() = 0;
37+
virtual const char* Name() = 0;
38+
virtual const char* Dtype() = 0;
39+
virtual const char* Place() = 0;
40+
};
41+
42+
static std::vector<BenchJITKernel*> g_all_benchmarks;
43+
44+
BenchJITKernel* InsertBenchmark(BenchJITKernel* b) {
45+
g_all_benchmarks.push_back(b);
46+
return b;
47+
}
48+
49+
#define BENCH_JITKERNEL(name, dtype, place) \
50+
class BenchJITKernel_##name##_##dtype##_##place##_ : public BenchJITKernel { \
51+
public: \
52+
const char* Name() override { return #name; } \
53+
const char* Dtype() override { return #dtype; } \
54+
const char* Place() override { return #place; } \
55+
void Run() override; \
56+
}; \
57+
static auto inserted_##name##_##dtype##_##place##_ UNUSED = \
58+
InsertBenchmark(new BenchJITKernel_##name##_##dtype##_##place##_()); \
59+
void BenchJITKernel_##name##_##dtype##_##place##_::Run()
60+
61+
#define BENCH_FP32_CPU(name) BENCH_JITKERNEL(name, FP32, CPU)
62+
63+
void RUN_ALL_BENCHMARK() {
64+
for (auto p : g_all_benchmarks) {
65+
if (!FLAGS_filter.empty() && FLAGS_filter != p->Name()) {
66+
continue;
67+
}
68+
LOG(INFO) << "Benchmark " << p->Name() << "." << p->Dtype() << "."
69+
<< p->Place();
70+
p->Run();
71+
}
72+
}
2973

3074
template <typename T>
3175
void RandomVec(const int n, T* a, const T lower = static_cast<T>(-20.f),
@@ -228,49 +272,70 @@ void BenchMatMulKernel() {
228272
}
229273
}
230274

275+
using T = float;
276+
using PlaceType = paddle::platform::CPUPlace;
277+
278+
// xyzn
279+
BENCH_FP32_CPU(kVMul) { BenchXYZNKernel<jit::kVMul, T, PlaceType>(); }
280+
281+
BENCH_FP32_CPU(kVAdd) { BenchXYZNKernel<jit::kVAdd, T, PlaceType>(); }
282+
283+
BENCH_FP32_CPU(kVAddRelu) { BenchXYZNKernel<jit::kVAddRelu, T, PlaceType>(); }
284+
285+
BENCH_FP32_CPU(kVSub) { BenchXYZNKernel<jit::kVSub, T, PlaceType>(); }
286+
287+
// axyn
288+
BENCH_FP32_CPU(kVScal) { BenchAXYNKernel<jit::kVScal, T, PlaceType>(); }
289+
290+
BENCH_FP32_CPU(kVAddBias) { BenchAXYNKernel<jit::kVAddBias, T, PlaceType>(); }
291+
292+
// xyn
293+
BENCH_FP32_CPU(kVRelu) { BenchXYNKernel<jit::kVRelu, T, PlaceType>(); }
294+
295+
BENCH_FP32_CPU(kVIdentity) { BenchXYNKernel<jit::kVIdentity, T, PlaceType>(); }
296+
297+
BENCH_FP32_CPU(kVSquare) { BenchXYNKernel<jit::kVSquare, T, PlaceType>(); }
298+
299+
BENCH_FP32_CPU(kVExp) { BenchXYNKernel<jit::kVExp, T, PlaceType>(); }
300+
301+
BENCH_FP32_CPU(kVSigmoid) { BenchXYNKernel<jit::kVSigmoid, T, PlaceType>(); }
302+
303+
BENCH_FP32_CPU(kVTanh) { BenchXYNKernel<jit::kVTanh, T, PlaceType>(); }
304+
305+
// lstm and peephole
306+
BENCH_FP32_CPU(kLSTMCtHt) { BenchLSTMKernel<jit::kLSTMCtHt, T, PlaceType>(); }
307+
308+
BENCH_FP32_CPU(kLSTMC1H1) { BenchLSTMKernel<jit::kLSTMC1H1, T, PlaceType>(); }
309+
310+
// gru functions
311+
BENCH_FP32_CPU(kGRUH1) { BenchGRUKernel<jit::kGRUH1, T, PlaceType>(); }
312+
313+
BENCH_FP32_CPU(kGRUHtPart1) {
314+
BenchGRUKernel<jit::kGRUHtPart1, T, PlaceType>();
315+
}
316+
317+
BENCH_FP32_CPU(kGRUHtPart2) {
318+
BenchGRUKernel<jit::kGRUHtPart2, T, PlaceType>();
319+
}
320+
321+
// seq pool function
322+
BENCH_FP32_CPU(kSeqPool) { BenchSeqPoolKernel<jit::kSeqPool, T, PlaceType>(); }
323+
324+
// matmul
325+
BENCH_FP32_CPU(kMatMul) { BenchMatMulKernel<jit::kMatMul, T, PlaceType>(); }
326+
231327
// Benchmark all jit kernels including jitcode, mkl and refer.
232328
// To use this tool, run command: ./benchmark [options...]
233329
// Options:
234330
// --burning: the burning time before count
235331
// --repeat: the repeat times
236332
// --max_size: the max size would be tested
333+
// --filter: the bench name would be run
237334
int main(int argc, char* argv[]) {
238335
gflags::ParseCommandLineFlags(&argc, &argv, true);
239336
google::InitGoogleLogging(argv[0]);
240337
LOG(INFO) << "Burning " << FLAGS_burning << " times, Repeat " << FLAGS_repeat
241338
<< " times.";
242-
using T = float;
243-
using PlaceType = paddle::platform::CPUPlace;
244-
// xyzn
245-
BenchXYZNKernel<jit::kVMul, T, PlaceType>();
246-
BenchXYZNKernel<jit::kVAdd, T, PlaceType>();
247-
BenchXYZNKernel<jit::kVAddRelu, T, PlaceType>();
248-
BenchXYZNKernel<jit::kVSub, T, PlaceType>();
249-
250-
// axyn
251-
BenchAXYNKernel<jit::kVScal, T, PlaceType>();
252-
BenchAXYNKernel<jit::kVAddBias, T, PlaceType>();
253-
254-
// xyn
255-
BenchXYNKernel<jit::kVRelu, T, PlaceType>();
256-
BenchXYNKernel<jit::kVIdentity, T, PlaceType>();
257-
BenchXYNKernel<jit::kVSquare, T, PlaceType>();
258-
BenchXYNKernel<jit::kVExp, T, PlaceType>();
259-
BenchXYNKernel<jit::kVSigmoid, T, PlaceType>();
260-
BenchXYNKernel<jit::kVTanh, T, PlaceType>();
261-
262-
// lstm and peephole
263-
BenchLSTMKernel<jit::kLSTMCtHt, T, PlaceType>();
264-
BenchLSTMKernel<jit::kLSTMC1H1, T, PlaceType>();
265-
266-
// gru functions
267-
BenchGRUKernel<jit::kGRUH1, T, PlaceType>();
268-
BenchGRUKernel<jit::kGRUHtPart1, T, PlaceType>();
269-
BenchGRUKernel<jit::kGRUHtPart2, T, PlaceType>();
270-
271-
// seq pool function
272-
BenchSeqPoolKernel<jit::kSeqPool, T, PlaceType>();
273339

274-
// matmul
275-
BenchMatMulKernel<jit::kMatMul, T, PlaceType>();
340+
RUN_ALL_BENCHMARK();
276341
}

paddle/fluid/operators/jit/test.cc

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
#include "paddle/fluid/platform/cpu_info.h"
2323
#include "paddle/fluid/platform/place.h"
2424

25+
static double acc = 1e-5;
26+
2527
template <typename T>
2628
void RandomVec(const int n, T* a, const T lower = static_cast<T>(-20.f),
2729
const T upper = static_cast<T>(20.f)) {
@@ -37,7 +39,7 @@ template <typename T>
3739
void ExpectEQ(const T* target, const T* refer, int n) {
3840
if (std::is_floating_point<T>::value) {
3941
for (int i = 0; i < n; ++i) {
40-
EXPECT_NEAR(target[i], refer[i], 1e-5);
42+
EXPECT_NEAR(target[i], refer[i], acc);
4143
}
4244
} else {
4345
for (int i = 0; i < n; ++i) {
@@ -62,7 +64,9 @@ namespace jit = paddle::operators::jit;
6264

6365
template <typename KernelTuples, typename... Args>
6466
struct TestFuncWithRefer {
65-
void operator()(const typename KernelTuples::func_type tgt, Args... args) {}
67+
void operator()(const typename KernelTuples::func_type tgt, Args... args) {
68+
LOG(FATAL) << "Should specify this function.";
69+
}
6670
};
6771

6872
template <typename T>
@@ -140,7 +144,8 @@ struct TestFuncWithRefer<jit::XYNTuples<T>, std::vector<T>, std::vector<T>> {
140144

141145
template <typename T>
142146
struct TestFuncWithRefer<jit::LSTMTuples<T>, std::vector<T>, std::vector<T>,
143-
std::vector<T>, std::vector<T>, std::vector<T>> {
147+
std::vector<T>, std::vector<T>, std::vector<T>,
148+
typename jit::LSTMTuples<T>::attr_type> {
144149
void operator()(const typename jit::LSTMTuples<T>::func_type tgt,
145150
const std::vector<T>& xsrc, const std::vector<T>& wp,
146151
const std::vector<T>& ct_1, const std::vector<T>& ct_ref,
@@ -185,7 +190,8 @@ struct TestFuncWithRefer<jit::LSTMTuples<T>, std::vector<T>, std::vector<T>,
185190

186191
template <typename T>
187192
struct TestFuncWithRefer<jit::GRUTuples<T>, std::vector<T>, std::vector<T>,
188-
std::vector<T>> {
193+
std::vector<T>,
194+
typename jit::GRUTuples<T>::attr_type> {
189195
void operator()(const typename jit::GRUTuples<T>::func_type tgt,
190196
const std::vector<T>& xsrc, const std::vector<T>& ht_1,
191197
const std::vector<T>& ht_ref,
@@ -212,8 +218,8 @@ struct TestFuncWithRefer<jit::GRUTuples<T>, std::vector<T>, std::vector<T>,
212218
};
213219

214220
template <typename T>
215-
struct TestFuncWithRefer<jit::SeqPoolTuples<T>, std::vector<T>,
216-
std::vector<T>> {
221+
struct TestFuncWithRefer<jit::SeqPoolTuples<T>, std::vector<T>, std::vector<T>,
222+
typename jit::SeqPoolTuples<T>::attr_type> {
217223
void operator()(const typename jit::SeqPoolTuples<T>::func_type tgt,
218224
const std::vector<T>& x, const std::vector<T>& yref,
219225
const typename jit::SeqPoolTuples<T>::attr_type& attr) {
@@ -385,8 +391,8 @@ void TestLSTMKernel() {
385391
std::vector<T> xsrc(4 * d), wp(3 * d), ct_1(d);
386392
std::vector<T> ct_ref(d), ht_ref(d), checked(2 * d);
387393
RandomVec<T>(4 * d, xsrc.data(), -2.f, 2.f);
388-
RandomVec<T>(3 * d, wp.data(), -2.f, 2.f);
389-
RandomVec<T>(d, ct_1.data(), -2.f, 2.f);
394+
RandomVec<T>(3 * d, wp.data(), -1.f, 1.f);
395+
RandomVec<T>(d, ct_1.data(), -1.f, 1.f);
390396
// x could be changed after compute, so copy to save src
391397
std::vector<T> x(xsrc.size());
392398
std::copy(xsrc.begin(), xsrc.end(), x.begin());
@@ -481,14 +487,17 @@ void TestSeqPoolKernel() {
481487
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
482488
void TestMatMulKernel() {
483489
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
490+
auto last_acc = acc;
491+
// TODO(intel): this should be acc issue of MKL
492+
acc = 1e-3;
484493
for (int m : {1, 2, 3, 4}) {
485494
for (int n : {1, 2, 3, 4}) {
486495
for (int k : TestSizes()) {
487496
auto ref = jit::GetRefer<KT, jit::MatMulTuples<T>>();
488497
EXPECT_TRUE(ref != nullptr);
489498
std::vector<T> a(m * k), b(k * n), c(m * n);
490-
RandomVec<T>(m * k, a.data(), -0.2f, 0.2f);
491-
RandomVec<T>(k * n, b.data(), -0.2f, 0.2f);
499+
RandomVec<T>(m * k, a.data(), -2.f, 2.f);
500+
RandomVec<T>(k * n, b.data(), -2.f, 2.f);
492501
const T* a_data = a.data();
493502
const T* b_data = b.data();
494503
T* c_data = c.data();
@@ -498,6 +507,7 @@ void TestMatMulKernel() {
498507
}
499508
}
500509
}
510+
acc = last_acc;
501511
}
502512

503513
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>

0 commit comments

Comments
 (0)