Skip to content

Commit 509c420

Browse files
Integrate KleidiAI BF16 SME2 Kernel Through Mlas SBGEMM Path
Signed-off-by: Patryk Kaiser <patryk.kaiser@arm.com>
1 parent 892b2f1 commit 509c420

File tree

11 files changed

+496
-0
lines changed

11 files changed

+496
-0
lines changed

cmake/onnxruntime_mlas.cmake

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -284,6 +284,7 @@ function(setup_kleidiai)
284284
target_sources(onnxruntime_mlas PRIVATE
285285
${MLAS_SRC_DIR}/kai_ukernel_interface.cpp
286286
${MLAS_SRC_DIR}/kleidiai/sgemm_kleidiai.cpp
287+
${MLAS_SRC_DIR}/kleidiai/sbgemm_kleidiai.cpp
287288
${MLAS_SRC_DIR}/kleidiai/convolve_kleidiai.cpp
288289
${MLAS_SRC_DIR}/kleidiai/qgemm_kleidiai.cpp
289290
)

onnxruntime/core/mlas/inc/mlas.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1967,6 +1967,7 @@ struct MLAS_SBGEMM_DATA_PARAMS {
19671967
bool AIsfp32 = false; /**< matrix A is fp32, needs to be converted to bf16*/
19681968
bool BIsfp32 = false; /**< matrix B is fp32, needs to be converted to bf16*/
19691969
bool ZeroMode = true; /**< true: C = A*B, false: C += A*B */
1970+
bool BIsPacked = false; /**< Whether B is pre-packed */
19701971
};
19711972

19721973
/**

onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -13,12 +13,15 @@
1313
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
1414
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.h"
1515

16+
#include "kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h"
17+
1618
#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla.h"
1719
#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.h"
1820
#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.h"
1921
#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.h"
2022
#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h"
2123
#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h"
24+
2225
#if defined(ENABLE_QMX_KERNELS)
2326
#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_qmx_mopa.h"
2427
#endif // ENABLE_QMX_KERNELS
@@ -125,6 +128,19 @@ const kai_matmul_clamp_f32_f32p_f32p_ukernel sgemm_gemm_sme2 =
125128
kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa,
126129
kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa};
127130

131+
// Function-pointer table for the KleidiAI
// matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa micro-kernel.
// Entries are given positionally, in this order: m_step, n_step, mr, nr, kr,
// sr, lhs_packed_offset, rhs_packed_offset, dst_offset, dst_size, run.
// NOTE(review): because the initializer is positional, the order presumably
// has to match the member order of kai_matmul_clamp_f32_bf16p_bf16p_ukernel
// in the KleidiAI interface header -- confirm before reordering or inserting.
const kai_matmul_clamp_f32_bf16p_bf16p_ukernel sbgemm_gemm_sme2 =
132+
{kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
133+
kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
134+
kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
135+
kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
136+
kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
137+
kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
138+
kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
139+
kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
140+
kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
141+
kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
142+
kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa};
143+
128144
#if defined(ENABLE_QMX_KERNELS)
129145
const kai_matmul_clamp_f32_f32p_f32p_ukernel sgemm_gemm_qmx =
130146
{kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_qmx_mopa,
@@ -181,3 +197,8 @@ const kai_matmul_clamp_f32_f32_f32p_ukernel& GetKleidiAISGemvUKernel() {
181197
return sgemm_gemv_sme;
182198
}
183199
}
200+
201+
// Returns a reference to the KleidiAI micro-kernel table used for SBGEMM.
// The result is always sbgemm_gemm_sme2; this function performs no CPU
// feature detection of its own.
// NOTE(review): callers presumably must have verified SME2 support before
// invoking any entry of the returned table -- confirm at the call sites.
const kai_matmul_clamp_f32_bf16p_bf16p_ukernel& GetKleidiAISBGemmUKernel() {
202+
// Currently only an SME2 variant exists for the bfloat16/SBGEMM kernel.
203+
return sbgemm_gemm_sme2;
204+
}

onnxruntime/core/mlas/lib/kai_ukernel_interface.h

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,8 +12,12 @@
1212

1313
#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p_interface.h"
1414

15+
#include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h"
16+
1517
const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemmUKernel();
1618
const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemvUKernel();
1719

1820
const kai_matmul_clamp_f32_f32p_f32p_ukernel& GetKleidiAISGemmUKernel();
1921
const kai_matmul_clamp_f32_f32_f32p_ukernel& GetKleidiAISGemvUKernel();
22+
23+
// Accessor for the f32 <- bf16p x bf16p (SBGEMM) micro-kernel table.
// NOTE(review): which kernel variant is returned is decided in the .cpp
// definition -- check there for any hardware requirements.
const kai_matmul_clamp_f32_bf16p_bf16p_ukernel& GetKleidiAISBGemmUKernel();

onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h

Lines changed: 30 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -105,6 +105,36 @@ MlasGemmBatch(
105105
MLAS_THREADPOOL* ThreadPool
106106
);
107107

108+
#if defined(__aarch64__) && defined(__linux__)
109+
// SBGEMM packed-B size query. Declared only when both __aarch64__ and
// __linux__ are defined.
// N, K: dimensions of the B matrix -- presumably K rows by N columns; confirm
//       against the definition.
// NOTE(review): presumably returns the byte size of the buffer that
// MlasSBGemmPackB writes into -- the unit and any "unsupported" sentinel
// value are not visible here; verify in sbgemm_kleidiai.cpp.
size_t
110+
MLASCALL
111+
MlasSBGemmPackBSize(
112+
size_t N,
113+
size_t K
114+
);
115+
116+
// SBGEMM B-matrix packing entry point. Declared only when both __aarch64__
// and __linux__ are defined.
// N, K:    dimensions of B -- presumably K rows by N columns; confirm.
// B:       source matrix, read as float.
// ldb:     presumably the leading dimension of B in elements -- confirm.
// PackedB: destination buffer for the packed data (untyped).
// NOTE(review): the bool result presumably reports whether this path handled
// the request, so the caller can fall back otherwise -- TODO confirm. The
// required size of PackedB presumably comes from MlasSBGemmPackBSize.
bool
117+
MLASCALL
118+
MlasSBGemmPackB(
119+
size_t N,
120+
size_t K,
121+
const float* B,
122+
size_t ldb,
123+
void* PackedB
124+
);
125+
126+
// Batched SBGEMM entry point. Declared only when both __aarch64__ and
// __linux__ are defined.
// M, N, K:    GEMM problem dimensions.
// Data:       pointer to MLAS_SBGEMM_DATA_PARAMS -- presumably an array of
//             BatchSize entries, one per GEMM; confirm against the definition.
// BatchSize:  number of GEMMs in the batch.
// ThreadPool: thread pool handle; whether nullptr is accepted is not visible
//             here.
// NOTE(review): the bool result presumably reports whether this path handled
// the batch, so the caller can fall back otherwise -- TODO confirm.
bool
127+
MLASCALL
128+
MlasSBGemmBatch(
129+
size_t M,
130+
size_t N,
131+
size_t K,
132+
const MLAS_SBGEMM_DATA_PARAMS* Data,
133+
size_t BatchSize,
134+
MLAS_THREADPOOL* ThreadPool
135+
);
136+
#endif
137+
108138
size_t
109139
MLASCALL
110140
MlasDynamicQgemmPackBSize(

0 commit comments

Comments
 (0)