Skip to content

Commit 530fd92

Browse files
committed
fix: remaining batch stuff
1 parent c23eee0 commit 530fd92

File tree

5 files changed

+154
-10
lines changed

5 files changed

+154
-10
lines changed

CMakeLists.txt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ option(SAVE_JITS_TO_FILE "Saves the jitted kernels into a file if activated." OF
2828

2929
if(SAVE_JITS_TO_FILE)
3030
message(NOTICE "The saved kernels can be disassembled with: 'objdump -D -b binary -m aarch64 <inputFile> > <outputFile>'")
31-
add_compile_definitions(SAVE_JITS_TO_FILE)
31+
# set per target
3232
endif()
3333

3434
# ==============================================================
@@ -222,10 +222,12 @@ endforeach()
222222
# TARGETS
223223
# =============================================================
224224
add_executable(tests "${SOURCE_FILEPATHS}" "${TEST_FILEPATHS}")
225+
if(SAVE_JITS_TO_FILE)
226+
target_compile_definitions(tests PUBLIC SAVE_JITS_TO_FILE)
227+
endif(SAVE_JITS_TO_FILE)
225228
target_link_libraries(tests PRIVATE Catch2::Catch2WithMain)
226229

227230
add_executable(benchmarks "${SOURCE_FILEPATHS}" "${BENCH_FILEPATHS}")
228-
target_compile_definitions(benchmarks PUBLIC SAVE_JITS_TO_FILE=0)
229231
target_link_libraries(benchmarks benchmark::benchmark_main)
230232

231233
# ==============================================================

src/main/Brgemm.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,12 @@ mini_jit::Brgemm::error_t mini_jit::Brgemm::generate(uint32_t m, uint32_t n, uin
1919
{
2020
return error_t::err_row_major_order_not_supported;
2121
}
22-
if (br_size != 1)
23-
{
24-
return error_t::err_batch_reduce_size_not_supported;
25-
}
22+
2623
if (br_size == 1 && (trans_a + trans_b + trans_c) == 0 && dtype == dtype_t::fp32)
2724
{
2825
fill_with_matmuls_no_batch_dim_column_major_fp32(m, n, k);
2926
}
30-
if (br_size > 1 && (trans_a + trans_b + trans_c) == 0 && dtype == dtype_t::fp32)
27+
else if (br_size > 1 && (trans_a + trans_b + trans_c) == 0 && dtype == dtype_t::fp32)
3128
{
3229
fill_with_matmuls_batch_dim_column_major_fp32(m, n, k, br_size);
3330
}
@@ -152,4 +149,7 @@ void mini_jit::Brgemm::fill_with_matmuls_batch_dim_column_major_fp32(uint32_t m,
152149
kernels::br_matmul_lt16_lt4nRest_k(native_kernel, n / 4, k, br_size, m % 16, n % 4);
153150
return;
154151
}
152+
153+
throw std::logic_error(
154+
std::format("Unhandled combination found for MxNxKxBatch matmul: m='{}', n='{}', k='{}', batch='{}'", m, n, k, br_size));
155155
}

src/main/kernels/br_matmul_16mRest_lt4nRest_k.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ void mini_jit::kernels::br_matmul_16mRest_lt4nRest_k(mini_jit::Kernel &kernel, c
1212
using namespace mini_jit::arm_instructions;
1313

1414
release_assert(m_loop_16 != 0, "Cannot process matrix with m loop of 0.");
15-
release_assert(n_loop_4 != 0, "Cannot proccess matrix with n loop of 0.");
1615
release_assert(k_loop != 0, "Cannot process matrix with k loop of 0.");
1716
release_assert(m_loop_rest != 0, "Cannot create a matrix with a rest of m equal to 0!");
1817
release_assert(m_loop_rest <= 15, "Cannot create a matrix with a rest of m larger than 15!");
@@ -106,7 +105,10 @@ void mini_jit::kernels::br_matmul_16mRest_lt4nRest_k(mini_jit::Kernel &kernel, c
106105
// ========================================================================================
107106
// Calculate m + rest but n is multiple of 4
108107
// ========================================================================================
109-
matmul_16mRest_4n_k(kernel, m_loop_16, n_loop_4, k_loop, m_loop_rest, false);
108+
if (n_loop_4 != 0)
109+
{
110+
matmul_16mRest_4n_k(kernel, m_loop_16, n_loop_4, k_loop, m_loop_rest, false);
111+
}
110112

111113
// Offset to the next matrix block
112114
// Here we want to start with the initial m value but n should be offset by the already calculated amount.

src/test/Brgemm.test.cpp

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,3 +173,79 @@ TEST_CASE("Test gemm generation (1≤M≤64, 1≤N≤64, K∈[1,16,32,64,128],ld
173173
generatorTest.SetKernel(kernel);
174174
generatorTest.RunTest(lda, ldb, ldc, lda * K, ldb * N);
175175
}
176+
177+
TEST_CASE("Test gemm generation (1≤M≤64, 1≤N≤64, K∈[1,16,32,64,128], 1≤BatchSize≤16, lda=M, ldb=K, and ldc=M) on random data",
178+
"[generation][correctness][gemm]")
179+
{
180+
auto M = GENERATE(range(1u, 64u + 1u, 1u));
181+
auto N = GENERATE(range(1u, 64u + 1u, 1u));
182+
auto K = GENERATE(1u, 16u, 32u, 64u, 128u);
183+
auto BatchSize = GENERATE(range(1u, 16u + 1u, 1u));
184+
185+
CAPTURE(M, N, K, BatchSize);
186+
187+
mini_jit::Brgemm gemm;
188+
mini_jit::Brgemm::error_t error = gemm.generate(M, N, K, BatchSize, 0, 0, 0, mini_jit::Brgemm::dtype_t::fp32);
189+
190+
switch (error)
191+
{
192+
case mini_jit::Brgemm::error_t::success:
193+
break;
194+
case mini_jit::Brgemm::error_t::err_batch_reduce_size_not_supported:
195+
FAIL("Error batch reduce size not supported.");
196+
break;
197+
case mini_jit::Brgemm::error_t::err_row_major_order_not_supported:
198+
FAIL("Error row major order not supported.");
199+
break;
200+
case mini_jit::Brgemm::error_t::err_wrong_dimension:
201+
FAIL("Error wrong dimension.");
202+
break;
203+
case mini_jit::Brgemm::error_t::err_wrong_dtype:
204+
FAIL("Error wrong dtype.");
205+
break;
206+
default:
207+
FAIL("Found unprocessed error type");
208+
break;
209+
}
210+
211+
mini_jit::Brgemm::kernel_t kernel = gemm.get_kernel();
212+
REQUIRE(kernel != nullptr);
213+
}
214+
215+
TEST_CASE("Test gemm generation (1≤M≤64, 1≤N≤64, K∈[1,16,32,64,128], 1≤BatchSize≤16, lda=M, ldb=K, and ldc=M) on counting data",
216+
"[generation][correctness][gemm]")
217+
{
218+
auto M = GENERATE(range(1u, 64u + 1u, 1u));
219+
auto N = GENERATE(range(1u, 64u + 1u, 1u));
220+
auto K = GENERATE(1u, 16u, 32u, 64u, 128u);
221+
auto BatchSize = GENERATE(range(1u, 16u + 1u, 1u));
222+
223+
CAPTURE(M, N, K, BatchSize);
224+
225+
mini_jit::Brgemm gemm;
226+
mini_jit::Brgemm::error_t error = gemm.generate(M, N, K, BatchSize, 0, 0, 0, mini_jit::Brgemm::dtype_t::fp32);
227+
228+
switch (error)
229+
{
230+
case mini_jit::Brgemm::error_t::success:
231+
break;
232+
case mini_jit::Brgemm::error_t::err_batch_reduce_size_not_supported:
233+
FAIL("Error batch reduce size not supported.");
234+
break;
235+
case mini_jit::Brgemm::error_t::err_row_major_order_not_supported:
236+
FAIL("Error row major order not supported.");
237+
break;
238+
case mini_jit::Brgemm::error_t::err_wrong_dimension:
239+
FAIL("Error wrong dimension.");
240+
break;
241+
case mini_jit::Brgemm::error_t::err_wrong_dtype:
242+
FAIL("Error wrong dtype.");
243+
break;
244+
default:
245+
FAIL("Found unprocessed error type");
246+
break;
247+
}
248+
249+
mini_jit::Brgemm::kernel_t kernel = gemm.get_kernel();
250+
REQUIRE(kernel != nullptr);
251+
}

src/test/kernels/matmul.bench.cpp

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,4 +61,68 @@ BENCHMARK_REGISTER_F(GemmFixture, BM_matmul)
6161
->ArgNames({"M", "N", "K"})
6262
->DisplayAggregatesOnly(true)
6363
->Apply(CustomArguments)
64-
->MinWarmUpTime(1.0); // WarmUp in seconds
64+
->MinWarmUpTime(1.0); // WarmUp in seconds
65+
66+
class BrGemmFixture : public benchmark::Fixture
67+
{
68+
public:
69+
std::vector<float> matrix_a, matrix_b, matrix_c;
70+
double flops;
71+
72+
void SetUp(::benchmark::State &state) override
73+
{
74+
flops = 0;
75+
76+
int M = state.range(0);
77+
int N = state.range(1);
78+
int K = state.range(2);
79+
int Batch = state.range(3);
80+
81+
matrix_a.resize(M * K * Batch);
82+
matrix_b.resize(K * N * Batch);
83+
matrix_c.resize(M * N * Batch);
84+
85+
fill_random_matrix_args(matrix_a.data(), M * K * Batch);
86+
fill_random_matrix_args(matrix_b.data(), K * N * Batch);
87+
fill_random_matrix_args(matrix_c.data(), M * N * Batch);
88+
}
89+
90+
void TearDown(::benchmark::State &state) override
91+
{
92+
state.counters["FLOPS"] = benchmark::Counter(flops, benchmark::Counter::kIsRate);
93+
}
94+
};
95+
96+
BENCHMARK_DEFINE_F(BrGemmFixture, BM_brMatmul)(benchmark::State &state)
97+
{
98+
int M = state.range(0);
99+
int N = state.range(1);
100+
int K = state.range(2);
101+
int Batch = state.range(3);
102+
103+
mini_jit::Brgemm brgemm;
104+
brgemm.generate(M, N, K, Batch, 0, 0, 0, mini_jit::Brgemm::dtype_t::fp32);
105+
auto kernel = brgemm.get_kernel();
106+
107+
for (auto _ : state)
108+
{
109+
kernel(matrix_a.data(), matrix_b.data(), matrix_c.data(), M, K, M, M * K, K * N);
110+
}
111+
112+
flops = M * N * K * Batch * 2 * state.iterations();
113+
}
114+
115+
static void CustomArgumentsBatch(benchmark::internal::Benchmark *b)
116+
{
117+
int Batch = 16;
118+
for (int M = 1; M <= 64; M += 1)
119+
for (int N = 1; N <= 64; N += 1)
120+
for (int K : {1, 16, 32, 64, 128})
121+
b->Args({M, N, K, Batch});
122+
}
123+
124+
BENCHMARK_REGISTER_F(BrGemmFixture, BM_brMatmul)
125+
->ArgNames({"M", "N", "K", "Batch"})
126+
->DisplayAggregatesOnly(true)
127+
->Apply(CustomArgumentsBatch)
128+
->MinWarmUpTime(0.3); // WarmUp in seconds

0 commit comments

Comments
 (0)