|
22 | 22 | #include "paddle/fluid/platform/device_tracer.h"
|
23 | 23 | #include "paddle/fluid/platform/place.h"
|
24 | 24 | #include "paddle/fluid/platform/port.h"
|
| 25 | +#include "paddle/fluid/platform/variant.h" // for UNUSED |
25 | 26 |
|
26 | 27 | DEFINE_int32(burning, 10, "Burning times.");
|
27 | 28 | DEFINE_int32(repeat, 3000, "Repeat times.");
|
28 | 29 | DEFINE_int32(max_size, 1000, "The Max size would be tested.");
|
| 30 | +DEFINE_string(filter, "", "The Benchmark name would be run."); |
| 31 | + |
| 32 | +class BenchJITKernel { |
| 33 | + public: |
| 34 | + BenchJITKernel() = default; |
| 35 | + virtual ~BenchJITKernel() = default; |
| 36 | + virtual void Run() = 0; |
| 37 | + virtual const char* Name() = 0; |
| 38 | + virtual const char* Dtype() = 0; |
| 39 | + virtual const char* Place() = 0; |
| 40 | +}; |
| 41 | + |
| 42 | +static std::vector<BenchJITKernel*> g_all_benchmarks; |
| 43 | + |
| 44 | +BenchJITKernel* InsertBenchmark(BenchJITKernel* b) { |
| 45 | + g_all_benchmarks.push_back(b); |
| 46 | + return b; |
| 47 | +} |
| 48 | + |
| 49 | +#define BENCH_JITKERNEL(name, dtype, place) \ |
| 50 | + class BenchJITKernel_##name##_##dtype##_##place##_ : public BenchJITKernel { \ |
| 51 | + public: \ |
| 52 | + const char* Name() override { return #name; } \ |
| 53 | + const char* Dtype() override { return #dtype; } \ |
| 54 | + const char* Place() override { return #place; } \ |
| 55 | + void Run() override; \ |
| 56 | + }; \ |
| 57 | + static auto inserted_##name##_##dtype##_##place##_ UNUSED = \ |
| 58 | + InsertBenchmark(new BenchJITKernel_##name##_##dtype##_##place##_()); \ |
| 59 | + void BenchJITKernel_##name##_##dtype##_##place##_::Run() |
| 60 | + |
| 61 | +#define BENCH_FP32_CPU(name) BENCH_JITKERNEL(name, FP32, CPU) |
| 62 | + |
| 63 | +void RUN_ALL_BENCHMARK() { |
| 64 | + for (auto p : g_all_benchmarks) { |
| 65 | + if (!FLAGS_filter.empty() && FLAGS_filter != p->Name()) { |
| 66 | + continue; |
| 67 | + } |
| 68 | + LOG(INFO) << "Benchmark " << p->Name() << "." << p->Dtype() << "." |
| 69 | + << p->Place(); |
| 70 | + p->Run(); |
| 71 | + } |
| 72 | +} |
29 | 73 |
|
30 | 74 | template <typename T>
|
31 | 75 | void RandomVec(const int n, T* a, const T lower = static_cast<T>(-20.f),
|
@@ -228,49 +272,70 @@ void BenchMatMulKernel() {
|
228 | 272 | }
|
229 | 273 | }
|
230 | 274 |
|
| 275 | +using T = float; |
| 276 | +using PlaceType = paddle::platform::CPUPlace; |
| 277 | + |
| 278 | +// xyzn |
| 279 | +BENCH_FP32_CPU(kVMul) { BenchXYZNKernel<jit::kVMul, T, PlaceType>(); } |
| 280 | + |
| 281 | +BENCH_FP32_CPU(kVAdd) { BenchXYZNKernel<jit::kVAdd, T, PlaceType>(); } |
| 282 | + |
| 283 | +BENCH_FP32_CPU(kVAddRelu) { BenchXYZNKernel<jit::kVAddRelu, T, PlaceType>(); } |
| 284 | + |
| 285 | +BENCH_FP32_CPU(kVSub) { BenchXYZNKernel<jit::kVSub, T, PlaceType>(); } |
| 286 | + |
| 287 | +// axyn |
| 288 | +BENCH_FP32_CPU(kVScal) { BenchAXYNKernel<jit::kVScal, T, PlaceType>(); } |
| 289 | + |
| 290 | +BENCH_FP32_CPU(kVAddBias) { BenchAXYNKernel<jit::kVAddBias, T, PlaceType>(); } |
| 291 | + |
| 292 | +// xyn |
| 293 | +BENCH_FP32_CPU(kVRelu) { BenchXYNKernel<jit::kVRelu, T, PlaceType>(); } |
| 294 | + |
| 295 | +BENCH_FP32_CPU(kVIdentity) { BenchXYNKernel<jit::kVIdentity, T, PlaceType>(); } |
| 296 | + |
| 297 | +BENCH_FP32_CPU(kVSquare) { BenchXYNKernel<jit::kVSquare, T, PlaceType>(); } |
| 298 | + |
| 299 | +BENCH_FP32_CPU(kVExp) { BenchXYNKernel<jit::kVExp, T, PlaceType>(); } |
| 300 | + |
| 301 | +BENCH_FP32_CPU(kVSigmoid) { BenchXYNKernel<jit::kVSigmoid, T, PlaceType>(); } |
| 302 | + |
| 303 | +BENCH_FP32_CPU(kVTanh) { BenchXYNKernel<jit::kVTanh, T, PlaceType>(); } |
| 304 | + |
| 305 | +// lstm and peephole |
| 306 | +BENCH_FP32_CPU(kLSTMCtHt) { BenchLSTMKernel<jit::kLSTMCtHt, T, PlaceType>(); } |
| 307 | + |
| 308 | +BENCH_FP32_CPU(kLSTMC1H1) { BenchLSTMKernel<jit::kLSTMC1H1, T, PlaceType>(); } |
| 309 | + |
| 310 | +// gru functions |
| 311 | +BENCH_FP32_CPU(kGRUH1) { BenchGRUKernel<jit::kGRUH1, T, PlaceType>(); } |
| 312 | + |
| 313 | +BENCH_FP32_CPU(kGRUHtPart1) { |
| 314 | + BenchGRUKernel<jit::kGRUHtPart1, T, PlaceType>(); |
| 315 | +} |
| 316 | + |
| 317 | +BENCH_FP32_CPU(kGRUHtPart2) { |
| 318 | + BenchGRUKernel<jit::kGRUHtPart2, T, PlaceType>(); |
| 319 | +} |
| 320 | + |
| 321 | +// seq pool function |
| 322 | +BENCH_FP32_CPU(kSeqPool) { BenchSeqPoolKernel<jit::kSeqPool, T, PlaceType>(); } |
| 323 | + |
| 324 | +// matmul |
| 325 | +BENCH_FP32_CPU(kMatMul) { BenchMatMulKernel<jit::kMatMul, T, PlaceType>(); } |
| 326 | + |
231 | 327 | // Benchmark all jit kernels including jitcode, mkl and refer.
|
232 | 328 | // To use this tool, run command: ./benchmark [options...]
|
233 | 329 | // Options:
|
234 | 330 | // --burning: the burning time before count
|
235 | 331 | // --repeat: the repeat times
|
236 | 332 | // --max_size: the max size would be tested
|
| 333 | +// --filter: the bench name would be run |
237 | 334 | int main(int argc, char* argv[]) {
|
238 | 335 | gflags::ParseCommandLineFlags(&argc, &argv, true);
|
239 | 336 | google::InitGoogleLogging(argv[0]);
|
240 | 337 | LOG(INFO) << "Burning " << FLAGS_burning << " times, Repeat " << FLAGS_repeat
|
241 | 338 | << " times.";
|
242 |
| - using T = float; |
243 |
| - using PlaceType = paddle::platform::CPUPlace; |
244 |
| - // xyzn |
245 |
| - BenchXYZNKernel<jit::kVMul, T, PlaceType>(); |
246 |
| - BenchXYZNKernel<jit::kVAdd, T, PlaceType>(); |
247 |
| - BenchXYZNKernel<jit::kVAddRelu, T, PlaceType>(); |
248 |
| - BenchXYZNKernel<jit::kVSub, T, PlaceType>(); |
249 |
| - |
250 |
| - // axyn |
251 |
| - BenchAXYNKernel<jit::kVScal, T, PlaceType>(); |
252 |
| - BenchAXYNKernel<jit::kVAddBias, T, PlaceType>(); |
253 |
| - |
254 |
| - // xyn |
255 |
| - BenchXYNKernel<jit::kVRelu, T, PlaceType>(); |
256 |
| - BenchXYNKernel<jit::kVIdentity, T, PlaceType>(); |
257 |
| - BenchXYNKernel<jit::kVSquare, T, PlaceType>(); |
258 |
| - BenchXYNKernel<jit::kVExp, T, PlaceType>(); |
259 |
| - BenchXYNKernel<jit::kVSigmoid, T, PlaceType>(); |
260 |
| - BenchXYNKernel<jit::kVTanh, T, PlaceType>(); |
261 |
| - |
262 |
| - // lstm and peephole |
263 |
| - BenchLSTMKernel<jit::kLSTMCtHt, T, PlaceType>(); |
264 |
| - BenchLSTMKernel<jit::kLSTMC1H1, T, PlaceType>(); |
265 |
| - |
266 |
| - // gru functions |
267 |
| - BenchGRUKernel<jit::kGRUH1, T, PlaceType>(); |
268 |
| - BenchGRUKernel<jit::kGRUHtPart1, T, PlaceType>(); |
269 |
| - BenchGRUKernel<jit::kGRUHtPart2, T, PlaceType>(); |
270 |
| - |
271 |
| - // seq pool function |
272 |
| - BenchSeqPoolKernel<jit::kSeqPool, T, PlaceType>(); |
273 | 339 |
|
274 |
| - // matmul |
275 |
| - BenchMatMulKernel<jit::kMatMul, T, PlaceType>(); |
| 340 | + RUN_ALL_BENCHMARK(); |
276 | 341 | }
|
0 commit comments