Skip to content

Commit 631389c

Browse files
committed
feat: Added higher leading dimension tests
1 parent 30e988c commit 631389c

14 files changed

+1426
-53
lines changed

src/test/BaseGeneration.test.cpp

Lines changed: 94 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,25 @@
11
#include "BaseGeneration.test.h"
22
#include "kernels/matmul.test.h"
3+
#include <cmath>
34

45
void GenerationTest::fill_random_matrix(float *matrix, uint32_t size)
56
{
67
std::srand(std::time(0));
78
for (size_t i = 0; i < size; i++)
89
{
9-
matrix[i] = (static_cast<float>(std::rand())) / (static_cast<float>(std::rand()));
10+
float denominator = 1;
11+
do
12+
{
13+
denominator = static_cast<float>(std::rand());
14+
} while (denominator == 0);
15+
16+
float numerator = 1;
17+
do
18+
{
19+
numerator = static_cast<float>(std::rand());
20+
} while (numerator == 0);
21+
22+
matrix[i] = numerator / denominator;
1023
}
1124
}
1225

@@ -41,14 +54,27 @@ void GenerationTest::verify_matmul(const float *__restrict__ expected, const flo
4154
for (size_t i = 0; i < size; i++)
4255
{
4356
CAPTURE(i, result[i], expected[i]);
44-
REQUIRE_THAT(result[i], Catch::Matchers::WithinRel(expected[i]));
57+
58+
if (std::isnan(expected[i]))
59+
{
60+
REQUIRE_THAT(result[i], Catch::Matchers::IsNaN());
61+
}
62+
else
63+
{
64+
REQUIRE_THAT(result[i], Catch::Matchers::WithinRel(expected[i]));
65+
}
4566
}
4667
}
4768

4869
GenerationTest::GenerationTest(uint32_t M, uint32_t N, uint32_t K) : GenerationTest(M, N, K, 1)
4970
{
5071
}
5172

73+
GenerationTest::GenerationTest(uint32_t M, uint32_t N, uint32_t K, uint32_t lda, uint32_t ldb, uint32_t ldc)
74+
: GenerationTest(M, N, K, 1, lda, ldb, ldc, lda * K, ldb * N)
75+
{
76+
}
77+
5278
GenerationTest::GenerationTest(uint32_t M, uint32_t N, uint32_t K, uint32_t BatchSize) : M(M), N(N), K(K), BatchSize(BatchSize)
5379
{
5480

@@ -58,6 +84,16 @@ GenerationTest::GenerationTest(uint32_t M, uint32_t N, uint32_t K, uint32_t Batc
5884
matrix_c_verify = new float[M * N];
5985
}
6086

87+
GenerationTest::GenerationTest(uint32_t M, uint32_t N, uint32_t K, uint32_t BatchSize, uint32_t lda, uint32_t ldb, uint32_t ldc,
88+
uint32_t batch_stride_a, uint32_t batch_stride_b)
89+
: M(M), N(N), K(K), BatchSize(BatchSize), lda(lda), ldb(ldb), ldc(ldc), batch_stride_a(batch_stride_a), batch_stride_b(batch_stride_b)
90+
{
91+
matrix_a = new float[batch_stride_a * BatchSize];
92+
matrix_b = new float[batch_stride_b * BatchSize];
93+
matrix_c = new float[ldc * N];
94+
matrix_c_verify = new float[ldc * N];
95+
}
96+
6197
GenerationTest::~GenerationTest()
6298
{
6399
delete[] matrix_a;
@@ -68,24 +104,46 @@ GenerationTest::~GenerationTest()
68104

69105
void GenerationTest::SetUp(TestInfill fillType)
70106
{
71-
switch (fillType)
107+
if (lda != 0)
72108
{
73-
case TestInfill::Random:
74-
fill_random_matrix(matrix_a, M * K * BatchSize);
75-
fill_random_matrix(matrix_b, K * N * BatchSize);
76-
fill_random_matrix(matrix_c, M * N);
77-
break;
78-
case TestInfill::Counting:
79-
fill_counting_matrix(matrix_a, M * K * BatchSize);
80-
fill_counting_matrix(matrix_b, K * N * BatchSize);
81-
fill_counting_matrix(matrix_c, M * N);
82-
break;
83-
default:
84-
FAIL("Undefined infill type found.");
85-
break;
109+
switch (fillType)
110+
{
111+
case TestInfill::Random:
112+
fill_random_matrix(matrix_a, batch_stride_a * BatchSize);
113+
fill_random_matrix(matrix_b, batch_stride_b * BatchSize);
114+
fill_random_matrix(matrix_c, ldc * N);
115+
break;
116+
case TestInfill::Counting:
117+
fill_counting_matrix(matrix_a, batch_stride_a * BatchSize);
118+
fill_counting_matrix(matrix_b, batch_stride_b * BatchSize);
119+
fill_counting_matrix(matrix_c, ldc * N);
120+
break;
121+
default:
122+
FAIL("Undefined infill type found.");
123+
break;
124+
}
125+
std::copy(matrix_c, matrix_c + ldc * N, matrix_c_verify);
126+
}
127+
else
128+
{
129+
switch (fillType)
130+
{
131+
case TestInfill::Random:
132+
fill_random_matrix(matrix_a, M * K * BatchSize);
133+
fill_random_matrix(matrix_b, K * N * BatchSize);
134+
fill_random_matrix(matrix_c, M * N);
135+
break;
136+
case TestInfill::Counting:
137+
fill_counting_matrix(matrix_a, M * K * BatchSize);
138+
fill_counting_matrix(matrix_b, K * N * BatchSize);
139+
fill_counting_matrix(matrix_c, M * N);
140+
break;
141+
default:
142+
FAIL("Undefined infill type found.");
143+
break;
144+
}
145+
std::copy(matrix_c, matrix_c + M * N, matrix_c_verify);
86146
}
87-
88-
std::copy(matrix_c, matrix_c + M * N, matrix_c_verify);
89147
}
90148

91149
void GenerationTest::SetKernel(mini_jit::Brgemm::kernel_t kernel)
@@ -102,10 +160,27 @@ void GenerationTest::RunTest(const uint32_t lda, const uint32_t ldb, const uint3
102160
FAIL("The kernel should be set before the test is executed.");
103161
}
104162

163+
if (GenerationTest::lda != 0)
164+
{
165+
// Verification of same lda, ldb, batch_stride_a, batch_stride_b
166+
REQUIRE(GenerationTest::lda == lda);
167+
REQUIRE(GenerationTest::ldb == ldb);
168+
REQUIRE(GenerationTest::ldc == ldc);
169+
REQUIRE(GenerationTest::batch_stride_a == batch_stride_a);
170+
REQUIRE(GenerationTest::batch_stride_b == batch_stride_b);
171+
}
172+
105173
// Run matmuls
106174
kernel(matrix_a, matrix_b, matrix_c, lda, ldb, ldc, batch_stride_a, batch_stride_b);
107175

108176
naive_matmul_M_N_K_Batch(matrix_a, matrix_b, matrix_c_verify, lda, ldb, ldc, batch_stride_a, batch_stride_b);
109177

110-
verify_matmul(matrix_c_verify, matrix_c, M * N);
178+
if (lda != 0)
179+
{
180+
verify_matmul(matrix_c_verify, matrix_c, ldc * N);
181+
}
182+
else
183+
{
184+
verify_matmul(matrix_c_verify, matrix_c, M * N);
185+
}
111186
}

src/test/BaseGeneration.test.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@ class GenerationTest
1212
uint32_t N;
1313
uint32_t K;
1414
uint32_t BatchSize;
15+
uint32_t lda = 0; // lda != 0 Use as indicator if any leading dimension higher than a size is used
16+
uint32_t ldb = 0;
17+
uint32_t ldc = 0;
18+
uint32_t batch_stride_a = 0;
19+
uint32_t batch_stride_b = 0;
1520
float *matrix_a;
1621
float *matrix_b;
1722
float *matrix_c;
@@ -61,7 +66,10 @@ class GenerationTest
6166
public:
6267
GenerationTest() = delete;
6368
GenerationTest(uint32_t M, uint32_t N, uint32_t K);
69+
GenerationTest(uint32_t M, uint32_t N, uint32_t K, uint32_t lda, uint32_t ldb, uint32_t ldc);
6470
GenerationTest(uint32_t M, uint32_t N, uint32_t K, uint32_t BatchSize);
71+
GenerationTest(uint32_t M, uint32_t N, uint32_t K, uint32_t BatchSize, uint32_t lda, uint32_t ldb, uint32_t ldc, uint32_t batch_stride_a,
72+
uint32_t batch_stride_b);
6573
~GenerationTest();
6674

6775
/**

src/test/Brgemm.test.cpp

Lines changed: 131 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
TEST_CASE("Test gemm generation (1≤M≤64, 1≤N≤64, K∈[1,16,32,64,128],lda=M, ldb=K, and ldc=M) on random data",
88
"[generation][correctness][gemm]")
99
{
10-
auto M = GENERATE(range(1u, 64u + 1u, 1u)); // TODO Replace with this if matmuls implemented
11-
auto N = GENERATE(range(1u, 64u + 1u, 1u)); // TODO Replace with this if matmuls implemented
10+
auto M = GENERATE(range(1u, 64u + 1u, 1u));
11+
auto N = GENERATE(range(1u, 64u + 1u, 1u));
1212
auto K = GENERATE(1u, 16u, 32u, 64u, 128u);
1313

1414
CAPTURE(M, N, K);
@@ -44,3 +44,132 @@ TEST_CASE("Test gemm generation (1≤M≤64, 1≤N≤64, K∈[1,16,32,64,128],ld
4444
generatorTest.SetKernel(kernel);
4545
generatorTest.RunTest(M, K, M, 0, 0);
4646
}
47+
48+
TEST_CASE("Test gemm generation (1≤M≤64, 1≤N≤64, K∈[1,16,32,64,128],lda=M, ldb=K, and ldc=M) on counting data",
49+
"[generation][correctness][gemm]")
50+
{
51+
auto M = GENERATE(range(1u, 64u + 1u, 1u));
52+
auto N = GENERATE(range(1u, 64u + 1u, 1u));
53+
auto K = GENERATE(1u, 16u, 32u, 64u, 128u);
54+
55+
CAPTURE(M, N, K);
56+
57+
GenerationTest generatorTest(M, N, K);
58+
generatorTest.SetUp(TestInfill::Counting);
59+
60+
mini_jit::Brgemm gemm;
61+
mini_jit::Brgemm::error_t error = gemm.generate(M, N, K, 1, 0, 0, 0, mini_jit::Brgemm::dtype_t::fp32);
62+
63+
switch (error)
64+
{
65+
case mini_jit::Brgemm::error_t::success:
66+
break;
67+
case mini_jit::Brgemm::error_t::err_batch_reduce_size_not_supported:
68+
FAIL("Error batch reduce size not supported.");
69+
break;
70+
case mini_jit::Brgemm::error_t::err_row_major_order_not_supported:
71+
FAIL("Error row major order not supported.");
72+
break;
73+
case mini_jit::Brgemm::error_t::err_wrong_dimension:
74+
FAIL("Error err wrong dimension.");
75+
break;
76+
case mini_jit::Brgemm::error_t::err_wrong_dtype:
77+
FAIL("Error wrong dtype.");
78+
break;
79+
default:
80+
FAIL("Found unprocessed error type");
81+
break;
82+
}
83+
84+
mini_jit::Brgemm::kernel_t kernel = gemm.get_kernel();
85+
generatorTest.SetKernel(kernel);
86+
generatorTest.RunTest(M, K, M, 0, 0);
87+
}
88+
89+
TEST_CASE("Test gemm generation (1≤M≤64, 1≤N≤64, K∈[1,16,32,64,128],lda>M, ldb>K, and ldc>M) on random data",
90+
"[generation][correctness][gemm]")
91+
{
92+
auto M = GENERATE(range(1u, 64u + 1u, 1u));
93+
auto N = GENERATE(range(1u, 64u + 1u, 1u));
94+
auto K = GENERATE(1u, 16u, 32u, 64u, 128u);
95+
const uint32_t lda = M + 5;
96+
const uint32_t ldb = K + 3;
97+
const uint32_t ldc = M + 7;
98+
99+
CAPTURE(M, N, K, lda, ldb, ldc);
100+
101+
GenerationTest generatorTest(M, N, K, lda, ldb, ldc);
102+
generatorTest.SetUp(TestInfill::Random);
103+
104+
mini_jit::Brgemm gemm;
105+
mini_jit::Brgemm::error_t error = gemm.generate(M, N, K, 1, 0, 0, 0, mini_jit::Brgemm::dtype_t::fp32);
106+
107+
switch (error)
108+
{
109+
case mini_jit::Brgemm::error_t::success:
110+
break;
111+
case mini_jit::Brgemm::error_t::err_batch_reduce_size_not_supported:
112+
FAIL("Error batch reduce size not supported.");
113+
break;
114+
case mini_jit::Brgemm::error_t::err_row_major_order_not_supported:
115+
FAIL("Error row major order not supported.");
116+
break;
117+
case mini_jit::Brgemm::error_t::err_wrong_dimension:
118+
FAIL("Error err wrong dimension.");
119+
break;
120+
case mini_jit::Brgemm::error_t::err_wrong_dtype:
121+
FAIL("Error wrong dtype.");
122+
break;
123+
default:
124+
FAIL("Found unprocessed error type");
125+
break;
126+
}
127+
128+
mini_jit::Brgemm::kernel_t kernel = gemm.get_kernel();
129+
generatorTest.SetKernel(kernel);
130+
generatorTest.RunTest(lda, ldb, ldc, lda * K, ldb * N);
131+
}
132+
133+
TEST_CASE("Test gemm generation (1≤M≤64, 1≤N≤64, K∈[1,16,32,64,128],lda>M, ldb>K, and ldc>M) on counting data",
134+
"[generation][correctness][gemm]")
135+
{
136+
auto M = GENERATE(range(1u, 64u + 1u, 1u));
137+
auto N = GENERATE(range(1u, 64u + 1u, 1u));
138+
auto K = GENERATE(1u, 16u, 32u, 64u, 128u);
139+
const uint32_t lda = M + 5;
140+
const uint32_t ldb = K + 3;
141+
const uint32_t ldc = M + 7;
142+
143+
CAPTURE(M, N, K, lda, ldb, ldc);
144+
145+
GenerationTest generatorTest(M, N, K, lda, ldb, ldc);
146+
generatorTest.SetUp(TestInfill::Counting);
147+
148+
mini_jit::Brgemm gemm;
149+
mini_jit::Brgemm::error_t error = gemm.generate(M, N, K, 1, 0, 0, 0, mini_jit::Brgemm::dtype_t::fp32);
150+
151+
switch (error)
152+
{
153+
case mini_jit::Brgemm::error_t::success:
154+
break;
155+
case mini_jit::Brgemm::error_t::err_batch_reduce_size_not_supported:
156+
FAIL("Error batch reduce size not supported.");
157+
break;
158+
case mini_jit::Brgemm::error_t::err_row_major_order_not_supported:
159+
FAIL("Error row major order not supported.");
160+
break;
161+
case mini_jit::Brgemm::error_t::err_wrong_dimension:
162+
FAIL("Error err wrong dimension.");
163+
break;
164+
case mini_jit::Brgemm::error_t::err_wrong_dtype:
165+
FAIL("Error wrong dtype.");
166+
break;
167+
default:
168+
FAIL("Found unprocessed error type");
169+
break;
170+
}
171+
172+
mini_jit::Brgemm::kernel_t kernel = gemm.get_kernel();
173+
generatorTest.SetKernel(kernel);
174+
generatorTest.RunTest(lda, ldb, ldc, lda * K, ldb * N);
175+
}

src/test/kernels/matmul.bench.cpp

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -51,21 +51,12 @@ BENCHMARK_DEFINE_F(GemmFixture, BM_matmul)(benchmark::State &state)
5151

5252
static void CustomArguments(benchmark::internal::Benchmark *b)
5353
{
54-
for (int M = 16; M <= 64; M += 16)
55-
for (int N = 16; N <= 64; N += 16)
54+
for (int M = 1; M <= 64; M += 1)
55+
for (int N = 1; N <= 64; N += 1)
5656
for (int K : {1, 16, 32, 64, 128})
5757
b->Args({M, N, K});
5858
}
5959

60-
// ########## UNCOMMENT WHEN brgemm.generate() supports m, n < 16 ##########
61-
// static void CustomArguments(benchmark::internal::Benchmark *b)
62-
// {
63-
// for (int M = 1; M <= 64; M += 1)
64-
// for (int N = 1; N <= 64; N += 1)
65-
// for (int K : {1, 16, 32, 64, 128})
66-
// b->Args({M, N, K});
67-
// }
68-
6960
BENCHMARK_REGISTER_F(GemmFixture, BM_matmul)
7061
->ArgNames({"M", "N", "K"})
7162
->DisplayAggregatesOnly(true)

0 commit comments

Comments
 (0)