1+ #include " matmul.bench.h"
2+ #include " ../../main/Brgemm.h"
3+ #include < benchmark/benchmark.h>
4+
5+ class GemmFixture : public benchmark ::Fixture
6+ {
7+ public:
8+ std::vector<float > matrix_a, matrix_b, matrix_c;
9+ double flops;
10+
11+ void SetUp (::benchmark::State &state) override
12+ {
13+ flops = 0 ;
14+
15+ int M = state.range (0 );
16+ int N = state.range (1 );
17+ int K = state.range (2 );
18+
19+ matrix_a.resize (M * K);
20+ matrix_b.resize (K * N);
21+ matrix_c.resize (M * N);
22+
23+ fill_random_matrix_args (matrix_a.data (), M * K);
24+ fill_random_matrix_args (matrix_b.data (), K * N);
25+ fill_random_matrix_args (matrix_c.data (), M * N);
26+ }
27+
28+ void TearDown (::benchmark::State &state) override
29+ {
30+ state.counters [" FLOPS" ] = benchmark::Counter (flops, benchmark::Counter::kIsRate );
31+ }
32+ };
33+
34+ BENCHMARK_DEFINE_F (GemmFixture, BM_matmul)(benchmark::State &state)
35+ {
36+ int M = state.range (0 );
37+ int N = state.range (1 );
38+ int K = state.range (2 );
39+
40+ mini_jit::Brgemm brgemm;
41+ brgemm.generate (M, N, K, 1 , 0 , 0 , 0 , mini_jit::Brgemm::dtype_t ::fp32);
42+ auto kernel = brgemm.get_kernel ();
43+
44+ for (auto _ : state)
45+ {
46+ kernel (matrix_a.data (), matrix_b.data (), matrix_c.data (), M, 1 , M, 1 , 1 );
47+ }
48+
49+ flops = M * N * K * 2 * state.iterations ();
50+ }
51+
52+ static void CustomArguments (benchmark::internal::Benchmark *b)
53+ {
54+ for (int M = 16 ; M <= 64 ; M += 16 )
55+ for (int N = 16 ; N <= 64 ; N += 16 )
56+ for (int K : {1 , 16 , 32 , 64 , 128 })
57+ b->Args ({M, N, K});
58+ }
59+
60+ // ########## UNCOMMENT WHEN brgemm.generate() supports m, n < 16 ##########
61+ // static void CustomArguments(benchmark::internal::Benchmark *b)
62+ // {
63+ // for (int M = 1; M <= 64; M += 1)
64+ // for (int N = 1; N <= 64; N += 1)
65+ // for (int K : {1, 16, 32, 64, 128})
66+ // b->Args({M, N, K});
67+ // }
68+
69+ BENCHMARK_REGISTER_F (GemmFixture, BM_matmul)
70+ ->ArgNames({" M" , " N" , " K" })
71+ ->ReportAggregatesOnly(true )
72+ ->Apply(CustomArguments)
73+ ->MinWarmUpTime(1.0 ); // WarmUp in seconds
0 commit comments