
Commit 5125ed2

JonathanC-ARM authored and fs-eire committed
Reworked sgemm_kleidi memory allocations to reuse memory buffers (#26166)
### **Key changes**

This PR changes the KleidiAI integration within the existing sgemm_kleidiai.cpp implementation. During internal testing it was noted that the overhead of repeatedly allocating vectors was having a negative impact on performance figures. The changes introduce thread-local buffers that reuse memory during inference. Android platforms are particularly sensitive to this; we have observed inference times being significantly impacted by memory allocation overheads.

### Example performance

All runs were captured using onnxruntime_perf_test, e.g.

`onnxruntime_perf_test -v -e cpu -I -m times -x 1 -y 1 -r 1000`

**Android Platform**

<img width="996" height="286" alt="image" src="https://github.com/user-attachments/assets/252165af-c864-4b24-b1f2-c28ada208b06" />

In addition, on M4 we have also observed slight improvements on models, although the gain is not as significant because allocation overhead accounts for a smaller share of total time on that platform.

**Mac Mini M4**

<img width="741" height="153" alt="image" src="https://github.com/user-attachments/assets/93e6c545-96fd-4bfc-b90f-3a845a1551bc" />

**Onnxruntime Mlas Benchmark**

The MLAS benchmark was executed on a Mac Mini M4 with SME2 instructions. The code was tested with and without the changes in this PR, with the following results observed (subset shown). The comparison was generated using compare.py, located in the Google Benchmark repo's tools directory.

`./onnxruntime_mlas_benchmark --benchmark_filter="SGEMM/NORMAL*" --benchmark_repetitions=100`

```
Benchmark                                              Time       CPU    Time Old   Time New   CPU Old   CPU New
-----------------------------------------------------------------------------------------------------------------
SGEMM/NORMAL_NoTrans/M:63/N:63/K:63/real_time        -0.1897   -0.1897      3270       2650      3270      2650
SGEMM/NORMAL_NoTrans/M:255/N:63/K:63/real_time       -0.1468   -0.1469      8383       7152      8382      7151
SGEMM/NORMAL_NoTrans/M:1023/N:63/K:63/real_time      -0.1506   -0.1506     19072      16200     19072     16200
SGEMM/NORMAL_NoTrans/M:63/N:255/K:63/real_time       -0.1957   -0.1957      7742       6227      7742      6227
SGEMM/NORMAL_NoTrans/M:255/N:255/K:63/real_time      -0.1032   -0.1032     14323      12845     14322     12845
SGEMM/NORMAL_TransB/M:63/N:63/K:63/real_time         -0.2221   -0.2221      3356       2611      3356      2610
SGEMM/NORMAL_TransB/M:255/N:63/K:63/real_time        -0.0439   -0.0438      8602       8224      8601      8224
SGEMM/NORMAL_TransB/M:1023/N:63/K:63/real_time       +0.0436   +0.0436     16488      17206     16487     17206
SGEMM/NORMAL_TransB/M:63/N:255/K:63/real_time        -0.2000   -0.1999      8046       6437      8046      6437
SGEMM/NORMAL_TransB/M:255/N:255/K:63/real_time       -0.0979   -0.0979     14131      12747     14130     12747
SGEMM/NORMAL_TransB/M:1023/N:255/K:63/real_time      -0.2836   -0.2836     62540      44802     62540     44802
SGEMM/NORMAL_TransB/M:63/N:1023/K:63/real_time       -0.2183   -0.2183     15342      11993     15342
```

Some small regressions have been seen but are difficult to explain; machine variance during the run is suspected to account for results like

```
SGEMM/NORMAL_TransB/M:1023/N:63/K:63/real_time       +0.0436   +0.0436     16488      17206     16487     17206
```

As part of testing these results, sgemm_kleidiai.cpp was instrumented (after the previous benchmark results were taken) with timer code in MlasGemmBatch, MlasGemmPackB, and MlasGemmPackBSize.
This produced the following, indicating that in this case the code performs better on average than the baseline currently at the head of main:

```
Head of main
Function             Count      Avg (ns)      Avg (pretty)
----------------------------------------------------------
MlasGemmBatch        42664     19601.015      19.601 us
MlasGemmPackB        42664       373.943     373.943 ns
MlasGemmPackBSize    42664        17.179      17.179 ns

TLB changes
Function             Count      Avg (ns)      Avg (pretty)
----------------------------------------------------------
MlasGemmBatch        55492     16985.256      16.985 us
MlasGemmPackB        55492       344.800     344.800 ns
MlasGemmPackBSize    55492        16.788      16.788 ns
```

---------

Signed-off-by: Jonathan Clohessy <[email protected]>
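Note: the timer code referred to above is not included in this commit. For illustration only, here is a minimal sketch of how such per-function averages could be collected. All names here (`ScopedTimer`, `TimingStats`, `g_stats`, `DumpTimingStats`) are hypothetical, and the global map is not thread-safe as written:

```cpp
#include <chrono>
#include <cstdio>
#include <string>
#include <unordered_map>

// Hypothetical accumulator: total nanoseconds and call count per function name.
struct TimingStats {
    long long total_ns = 0;
    long long count = 0;
};
static std::unordered_map<std::string, TimingStats> g_stats;

// RAII timer: adds the elapsed time for `name` to g_stats on scope exit.
// Not thread-safe as written; real instrumentation would need a mutex or
// per-thread accumulation.
class ScopedTimer {
public:
    explicit ScopedTimer(const char* name)
        : name_(name), start_(std::chrono::steady_clock::now()) {}
    ~ScopedTimer() {
        const auto ns = std::chrono::duration_cast<std::chrono::nanoseconds>(
                            std::chrono::steady_clock::now() - start_).count();
        TimingStats& s = g_stats[name_];
        s.total_ns += ns;
        ++s.count;
    }
private:
    const char* name_;
    std::chrono::steady_clock::time_point start_;
};

// Print per-function averages, similar in spirit to the tables above.
void DumpTimingStats() {
    for (const auto& entry : g_stats) {
        const TimingStats& s = entry.second;
        std::printf("%-20s count=%10lld avg=%12.3f ns\n",
                    entry.first.c_str(), s.count,
                    s.count ? static_cast<double>(s.total_ns) / s.count : 0.0);
    }
}
```

Each instrumented function would then begin with, e.g., `ScopedTimer t("MlasGemmBatch");`, with `DumpTimingStats()` called once at the end of the run.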
1 parent 92791ab · commit 5125ed2

File tree

2 files changed (+96 −38 lines)

onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h

Lines changed: 34 additions & 0 deletions
```diff
@@ -115,3 +115,37 @@ MlasConv(
     MLAS_THREADPOOL* ThreadPool
     );
 }
+
+/*++
+
+Routine Description:
+
+    This routine determines if a wraparound will occur when multiplying two size_t variables.
+    Uses __builtin_mul_overflow if available on the current system, and if not falls back
+    to a default implementation to check for wraparound.
+
+Arguments:
+
+    a - Supplies the first number to be multiplied.
+
+    b - Supplies the second number to be multiplied.
+
+    out - Pointer to a size_t which acts as the return value in success cases.
+
+Return Value:
+
+    Returns false if the operation was successful.
+    Returns true if wraparound of size_t was detected.
+
+--*/
+inline bool mul_overflow_size_t_builtin(size_t a, size_t b, size_t* out) {
+#if defined(__has_builtin)
+#  if __has_builtin(__builtin_mul_overflow)
+    return __builtin_mul_overflow(a, b, out);
+#  endif
+#endif
+    // Fallback to manual check if builtin not available
+    if (b != 0 && a > SIZE_MAX / b) return true;
+    if (out) *out = a * b;
+    return false;
+}
```
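As an illustration of the helper's calling convention, a small self-contained driver (the driver itself is not part of the commit; the helper body is copied from the diff above):

```cpp
#include <cstddef>
#include <cstdint>  // SIZE_MAX
#include <cstdio>

// Helper body as added to mlasi_kleidiai.h above.
inline bool mul_overflow_size_t_builtin(size_t a, size_t b, size_t* out) {
#if defined(__has_builtin)
#  if __has_builtin(__builtin_mul_overflow)
    return __builtin_mul_overflow(a, b, out);
#  endif
#endif
    // Fallback to manual check if builtin not available
    if (b != 0 && a > SIZE_MAX / b) return true;
    if (out) *out = a * b;
    return false;
}

int main() {
    size_t bytes = 0;

    // In-range product: returns false and writes the result through `out`.
    if (!mul_overflow_size_t_builtin(1024, 512, &bytes)) {
        std::printf("ok: %zu bytes\n", bytes);
    }

    // Wraparound: returns true; the caller bails out, mirroring the
    // "fallback to MLAS" paths in the .cpp diff below.
    if (mul_overflow_size_t_builtin(SIZE_MAX / 2, 4, &bytes)) {
        std::puts("overflow detected");
    }
    return 0;
}
```

Returning `false` on success keeps the call sites compact: every overflow check below reads as `if (mul_overflow_size_t_builtin(...)) return false;`, i.e. fall back to MLAS.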

onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp

Lines changed: 62 additions & 38 deletions
```diff
@@ -14,6 +14,16 @@
 #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.h"
 #include "mlasi_kleidiai.h"
 
+
+// Thread-local reusable buffers to reduce allocation overhead across tiles.
+struct KaiTlsBuffers {
+    std::vector<float> output_tile;
+    std::vector<float> bias_zero;
+    std::vector<std::byte> rhs_packed;
+    std::vector<std::byte> lhs_packed;
+};
+static thread_local KaiTlsBuffers g_kai_tls;
+
 size_t
 MLASCALL
 ArmKleidiAI::MlasGemmPackBSize(
@@ -51,7 +61,6 @@ Return Value:
     // Compute the number of bytes required to hold the packed buffer.
     //
     size_t bytes = 0;
-
     if (TransA == CblasNoTrans) {
         switch (TransB) {
             case CblasNoTrans:
@@ -125,15 +134,15 @@ Return Value:
     const size_t sr = UseSME2 ? kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
                               : kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
 
-    // pass zeroed bias values
-    const std::vector<float> bias(N);
+    // Ensure size and zero the used span.
+    g_kai_tls.bias_zero.resize(N, 0.0f);
 
     switch (TransB) {
         case CblasNoTrans:
-            kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(1, N, K, nr, kr, sr, ldb * sizeof(float), B, bias.data(), nullptr, PackedB, 0, nullptr);
+            kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(1, N, K, nr, kr, sr, ldb * sizeof(float), B, g_kai_tls.bias_zero.data(), nullptr, PackedB, 0, nullptr);
             break;
         case CblasTrans:
-            kai_run_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(1, N, K, nr, kr, sr, ldb * sizeof(float), B, bias.data(), nullptr, PackedB, 0, nullptr);
+            kai_run_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(1, N, K, nr, kr, sr, ldb * sizeof(float), B, g_kai_tls.bias_zero.data(), nullptr, PackedB, 0, nullptr);
             break;
         default:
             return false;
@@ -225,59 +234,61 @@ Return Value:
     size_t n_step = UseSME2 ? kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
                             : kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
 
-    if (M < m_step && N < n_step && !Data->BIsPacked) {
+    if ((M < m_step || N < n_step) && !Data->BIsPacked) {
         // Fallback to MLAS
         return false;
     }
 
-    std::vector<MLAS_SGEMM_DATA_PARAMS> KaiPackedData;
-    KaiPackedData.resize(BatchSize);
-
     size_t LhsPackedStride = 0;
     std::byte* LhsPackedData = nullptr;
 
     LhsPackedStride = kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme(M, K, mr, kr, sr);
-    auto LhsPacked = std::make_unique<std::byte[]>(LhsPackedStride * BatchSize);
-    LhsPackedData = LhsPacked.get();
 
-    std::unique_ptr<std::byte[]> RhsPacked{nullptr};
+    size_t lhs_resize = 0;
+    if (mul_overflow_size_t_builtin(LhsPackedStride, BatchSize, &lhs_resize))
+    {
+        // size_t wraparound detected for LhsPackedStride, fallback to MLAS
+        return false;
+    }
+
+    g_kai_tls.lhs_packed.resize(lhs_resize);
+    LhsPackedData = g_kai_tls.lhs_packed.data();
+
+    // RHS packed buffer: use TLS reusable vector to minimize allocations
+    size_t RhsPackedStride = 0;
+    std::byte* RhsPackedData = nullptr;
 
     // It is assumed all B batches require packing or not
     if (Data[0].BIsPacked) {
         // We have already decided the matmul variant we are using, before having values for M,N,K
         MlasTrySimpleParallel(ThreadPool, BatchSize, [&](ptrdiff_t batch_idx) {
             std::byte* LhsPackedPtr = &(LhsPackedData[LhsPackedStride * batch_idx]);
             kai_run_lhs_pack_f32p2vlx1_f32_sme(M, K, mr, kr, sr, 0, Data[batch_idx].A, Data[batch_idx].lda * sizeof(float), LhsPackedPtr);
-            KaiPackedData[batch_idx].A = reinterpret_cast<const float*>(LhsPackedPtr);
-            KaiPackedData[batch_idx].B = Data[batch_idx].B;
         });
     } else {
         // Multithread pack lhs and rhs
-        size_t RhsPackedStride = 0;
-        std::byte* RhsPackedData = nullptr;
-
         RhsPackedStride = ArmKleidiAI::MlasGemmPackBSize(TransA, TransB, N, K);
-        RhsPacked = std::make_unique<std::byte[]>(RhsPackedStride * BatchSize);
-        RhsPackedData = RhsPacked.get();
+        size_t rhs_resize = 0;
+        if (mul_overflow_size_t_builtin(RhsPackedStride, BatchSize, &rhs_resize))
+        {
+            // size_t wraparound detected for RhsPackedStride, fallback to MLAS
+            return false;
+        }
+
+        g_kai_tls.rhs_packed.resize(rhs_resize);
+        RhsPackedData = g_kai_tls.rhs_packed.data();
 
         MlasTrySimpleParallel(ThreadPool, BatchSize * 2, [&](ptrdiff_t batch_idx) {
-            // lhs odd, rhs even
             if (batch_idx & 0x1) {
                 batch_idx >>= 1;
-
                 std::byte* LhsPackedPtr = &(LhsPackedData[LhsPackedStride * batch_idx]);
-
                 kai_run_lhs_pack_f32p2vlx1_f32_sme(M, K, mr, kr, sr, 0, Data[batch_idx].A, Data[batch_idx].lda * sizeof(float), LhsPackedPtr);
-
-                KaiPackedData[batch_idx].A = reinterpret_cast<const float*>(LhsPackedPtr);
             } else {
                 batch_idx >>= 1;
-
                 std::byte* RhsPackedPtr = &(RhsPackedData[RhsPackedStride * batch_idx]);
-
-                ArmKleidiAI::MlasGemmPackB(TransA, TransB, N, K, reinterpret_cast<const float*>(Data[batch_idx].B), Data[batch_idx].ldb, RhsPackedPtr);
-
-                KaiPackedData[batch_idx].B = reinterpret_cast<const float*>(RhsPackedPtr);
+                ArmKleidiAI::MlasGemmPackB(TransA, TransB, N, K,
+                                           reinterpret_cast<const float*>(Data[batch_idx].B),
+                                           Data[batch_idx].ldb, RhsPackedPtr);
             }
         });
     }
@@ -303,6 +314,14 @@ Return Value:
     dim[1] = MlasDivRoundup(M, m_step);
     dim[2] = MlasDivRoundup(N, n_step);
 
+    // Pre-check maximum tile size to avoid per-iteration overflow inside the parallel loop.
+    // Any TileSizeM/TileSizeN used below will be <= m_step/n_step respectively.
+    size_t max_tile_elems = 0;
+    if (mul_overflow_size_t_builtin(m_step, n_step, &max_tile_elems)) {
+        // size_t wraparound detected for tile size, fallback to MLAS
+        return false;
+    }
+
     MlasTrySimpleParallel(ThreadPool, static_cast<ptrdiff_t>(dim[0] * dim[1] * dim[2]), [=](ptrdiff_t tid) {
         // compute B,M,N index from iteration index
         ptrdiff_t BIdx = tid / (dim[1] * dim[2]);
@@ -314,18 +333,18 @@ Return Value:
             UseSME2 ? kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(NIdx * n_step, K)
                     : kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(NIdx * n_step, K);
 
-        auto BTile = reinterpret_cast<const void*>(
-            reinterpret_cast<const std::byte*>(KaiPackedData[BIdx].B) + rhs_packed_offset
-        );
+        const std::byte* B_base = Data[0].BIsPacked
+                                      ? reinterpret_cast<const std::byte*>(Data[BIdx].B)
+                                      : (RhsPackedData + RhsPackedStride * BIdx);
+        auto BTile = reinterpret_cast<const void*>(B_base + rhs_packed_offset);
 
         // Get lhs tile, A
         const size_t lhs_packed_offset =
             UseSME2 ? kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(MIdx * m_step, K)
                     : kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(MIdx * m_step, K);
 
-        auto ATile = reinterpret_cast<const float*>(
-            reinterpret_cast<const std::byte*>(KaiPackedData[BIdx].A) + lhs_packed_offset
-        );
+        const std::byte* A_base = LhsPackedData + LhsPackedStride * BIdx;
+        auto ATile = reinterpret_cast<const float*>(A_base + lhs_packed_offset);
 
         auto TileSizeM = (MIdx + 1) * m_step > M ? (M - MIdx * m_step) : m_step;
         auto TileSizeN = (NIdx + 1) * n_step > N ? (N - NIdx * n_step) : n_step;
@@ -336,9 +355,14 @@ Return Value:
             MIdx * m_step * Data[BIdx].ldc * sizeof(float) +
             NIdx * n_step * sizeof(float)
         );
-        // Allocate temporary buffer for raw A*B result
-        std::vector<float> OutputTile(TileSizeM * TileSizeN, 0.0f);
-        float* temp_tile = OutputTile.data();
+        // Allocate temporary buffer for raw A*B result (TLS reusable buffer)
+        size_t tile_elems = TileSizeM * TileSizeN;
+
+        // resize the tile to the required size
+        g_kai_tls.output_tile.resize(tile_elems);
+
+        float* temp_tile = g_kai_tls.output_tile.data();
+        std::fill_n(temp_tile, tile_elems, 0.0f);
 
         if (UseSME2) {
             kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(
```
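A note on why the reuse works: `std::vector::resize` never reduces capacity, so each `thread_local` buffer stops allocating once it has reached its high-water mark on a given thread, and `std::fill_n` zeroes only the span actually in use. A standalone sketch of that pattern (illustrative names, not ORT code):

```cpp
#include <algorithm>  // std::fill_n
#include <cstddef>
#include <cstdio>
#include <vector>

// Illustrative only: a thread_local scratch vector keeps its capacity
// between calls, so resize() stops allocating once the largest requested
// size has been seen on this thread.
static thread_local std::vector<float> g_scratch;

float* GetZeroedScratch(size_t elems) {
    g_scratch.resize(elems);                     // reuses existing capacity when possible
    std::fill_n(g_scratch.data(), elems, 0.0f);  // zero only the span actually used
    return g_scratch.data();
}

int main() {
    for (int i = 0; i < 3; ++i) {
        float* tile = GetZeroedScratch(64 * 64);  // allocates on the first call only
        tile[0] = 1.0f;
        std::printf("iteration %d, capacity=%zu\n", i, g_scratch.capacity());
    }
    return 0;
}
```

This is what replaces the previous per-tile `std::vector<float> OutputTile(TileSizeM * TileSizeN, 0.0f);`, which allocated and zero-initialized a fresh buffer on every tile of every call.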
