Skip to content

Commit 4bd9b5e

Browse files
Integrate KleidiAI BF16 SME2 Kernel Through Mlas SBGEMM Path
Signed-off-by: Patryk Kaiser <patryk.kaiser@arm.com>
1 parent 5f087c4 commit 4bd9b5e

File tree

12 files changed

+809
-23
lines changed

12 files changed

+809
-23
lines changed

cmake/onnxruntime_mlas.cmake

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,7 @@ function(setup_kleidiai)
284284
target_sources(onnxruntime_mlas PRIVATE
285285
${MLAS_SRC_DIR}/kai_ukernel_interface.cpp
286286
${MLAS_SRC_DIR}/kleidiai/sgemm_kleidiai.cpp
287+
${MLAS_SRC_DIR}/kleidiai/sbgemm_kleidiai.cpp
287288
${MLAS_SRC_DIR}/kleidiai/convolve_kleidiai.cpp
288289
${MLAS_SRC_DIR}/kleidiai/qgemm_kleidiai.cpp
289290
)
@@ -550,6 +551,23 @@ else()
550551
set_source_files_properties(${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
551552
endif()
552553

554+
#
555+
# Mac arm64 (M-series) supports BF16 instructions, and ORT's SBGEMM
556+
# entrypoints (MlasSBGemm*) are provided by the AArch64 SBGEMM sources.
557+
# Add the minimal changes required to enable SBGEMM on this platform.
558+
#
559+
if (APPLE AND CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
560+
set(mlas_platform_srcs
561+
${mlas_platform_srcs}
562+
${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S
563+
${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp
564+
${MLAS_SRC_DIR}/sbconv_kernel_neon.cpp
565+
)
566+
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
567+
set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
568+
set_source_files_properties(${MLAS_SRC_DIR}/sbconv_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
569+
endif()
570+
553571
if(ONNXRUNTIME_MLAS_MULTI_ARCH)
554572
onnxruntime_add_static_library(onnxruntime_mlas_arm64 ${mlas_platform_srcs})
555573
set_target_properties(onnxruntime_mlas_arm64 PROPERTIES OSX_ARCHITECTURES "arm64")

onnxruntime/core/mlas/inc/mlas.h

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1967,6 +1967,7 @@ struct MLAS_SBGEMM_DATA_PARAMS {
19671967
bool AIsfp32 = false; /**< matrix A is fp32, needs to be converted to bf16*/
19681968
bool BIsfp32 = false; /**< matrix B is fp32, needs to be converted to bf16*/
19691969
bool ZeroMode = true; /**< true: C = A*B, false: C += A*B */
1970+
bool BIsPacked = false; /**< Whether B is pre-packed */
19701971
};
19711972

19721973
/**
@@ -1976,6 +1977,8 @@ struct MLAS_SBGEMM_DATA_PARAMS {
19761977
* Note: We only support uniform batching, so shapes and types of the
19771978
* input must be same across all parameter blocks.
19781979
*
1980+
* @param[in] TransA Supplies the transpose operation for matrix A.
1981+
* @param[in] TransB Supplies the transpose operation for matrix B.
19791982
* @param[in] M row size of matrix A and C
19801983
* @param[in] N column size of matrix B and C
19811984
* @param[in] K column size of matrix A and row size of matrix B
@@ -1985,31 +1988,61 @@ struct MLAS_SBGEMM_DATA_PARAMS {
19851988
* @return
19861989
*/
19871990
void MLASCALL
1988-
MlasSBGemmBatch(const size_t M, const size_t N, const size_t K, const size_t BatchN, const MLAS_SBGEMM_DATA_PARAMS* DataParams, MLAS_THREADPOOL* ThreadPool);
1991+
MlasSBGemmBatch(
1992+
CBLAS_TRANSPOSE TransA,
1993+
CBLAS_TRANSPOSE TransB,
1994+
const size_t M,
1995+
const size_t N,
1996+
const size_t K,
1997+
const size_t BatchN,
1998+
const MLAS_SBGEMM_DATA_PARAMS* DataParams,
1999+
MLAS_THREADPOOL* ThreadPool = nullptr
2000+
);
19892001

19902002
/**
19912003
* @brief For bfloat16 precision GEMM, returns size of the
19922004
* packing buffer needed for right hand side
1993-
* @param[in] N Number of columns
1994-
* @param[in] K Number of rows
2005+
* @param[in] TransA Supplies the transpose operation for matrix A.
2006+
* @param[in] TransB Supplies the transpose operation for matrix B.
2007+
* @param[in] BIsfp32 Is matrix B datatype FP32
2008+
* @param[in] N Number of columns
2009+
* @param[in] K Number of rows
19952010
* @return size of the packing buffer,
19962011
* 0 if operation not supported
19972012
*/
19982013
size_t MLASCALL
1999-
MlasSBGemmPackBSize(size_t N, size_t K);
2014+
MlasSBGemmPackBSize(
2015+
CBLAS_TRANSPOSE TransA,
2016+
CBLAS_TRANSPOSE TransB,
2017+
bool BIsfp32,
2018+
size_t N,
2019+
size_t K
2020+
);
20002021

20012022
/**
20022023
* @brief For bfloat16 precision GEMM, convert the float matrix B
20032024
 * to bfloat16 precision and pack it into a packing buffer
20042025
*
2026+
* @param[in] TransA Supplies the transpose operation for matrix A.
2027+
* @param[in] TransB Supplies the transpose operation for matrix B.
2028+
* @param[in] BIsfp32 Is matrix B datatype FP32
20052029
* @param[in] N Number of columns
20062030
* @param[in] K Number of rows
20072031
* @param[in] B Address of matrix B
20082032
* @param[in] ldb leading dimension of input matrix B
20092033
* @param[out] PackedB Address of the packed matrix
20102034
*/
20112035
void MLASCALL
2012-
MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* PackedB);
2036+
MlasSBGemmConvertPackB(
2037+
CBLAS_TRANSPOSE TransA,
2038+
CBLAS_TRANSPOSE TransB,
2039+
bool BIsfp32,
2040+
size_t N,
2041+
size_t K,
2042+
const float* B,
2043+
size_t ldb,
2044+
void* PackedB
2045+
);
20132046
#endif
20142047

20152048
/**

onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,11 @@
3030
#include "kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h"
3131

3232
// SME2 kernels
33-
// GEMM/QGEMM
33+
// GEMM/QGEMM/SBGEMM
3434
#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h"
3535
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h"
36+
#include "kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h"
37+
3638
// GEMV
3739
#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.h"
3840
#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.h"
@@ -227,6 +229,9 @@ const KaiF32IMatmulKernel imatmul_conv_sme =
227229
const KaiF32IMatmulKernel imatmul_conv_sme2 =
228230
KAI_WRAP_UKERNEL_RUN_IMATMUL_PACKED_7(imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa);
229231

232+
const KaiBF16SBgemmKernel sbgemm_gemm_sme2 =
233+
KAI_WRAP_UKERNEL_RUN_MATMUL_11(matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa);
234+
230235
#if defined(ENABLE_QMX_KERNELS)
231236
const KaiF32IMatmulKernel imatmul_conv_qmx =
232237
KAI_WRAP_UKERNEL_RUN_IMATMUL_PACKED_7(imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_qmx_mopa);
@@ -362,3 +367,8 @@ const KaiDynamicQGemmKernel& GetKleidiAIQGemmUKernel() {
362367
#endif // ENABLE_QMX_KERNELS
363368
}
364369
}
370+
371+
const KaiBF16SBgemmKernel& GetKleidiAISBGemmUKernel() {
372+
// Currently only SME2 variant exists for bfloat16/SBGEMM kernel
373+
return sbgemm_gemm_sme2;
374+
}

onnxruntime/core/mlas/lib/kai_ukernel_interface.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p_interface.h"
1414

15+
#include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h"
16+
1517
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp_qsi8cxp_interface.h"
1618

1719
#include "kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p_f32p_interface.h"
@@ -39,6 +41,8 @@ using KaiDynamicQGemmKernel = KaiMatmulKernel<kai_matmul_clamp_f32_qai8dxp_qsi8c
3941
// Wrapper for FP32 IMATMUL kernels used by the KleidiAI convolution implementation.
4042
using KaiF32IMatmulKernel = KaiMatmulKernel<kai_imatmul_clamp_f32_f32p_f32p_ukernel>;
4143

44+
using KaiBF16SBgemmKernel = KaiMatmulKernel<kai_matmul_clamp_f32_bf16p_bf16p_ukernel>;
45+
4246
// Returns the selected Qnbit GEMM ukernel based on runtime CPU capabilities.
4347
const KaiQnbitGemmKernel& GetKleidiAIGemmUKernel();
4448

@@ -56,3 +60,6 @@ const KaiF32SgemvKernel& GetKleidiAISGemvUKernel();
5660

5761
// Returns the selected FP32 IMATMUL ukernel used by the KleidiAI convolution implementation.
5862
const KaiF32IMatmulKernel& GetKleidiAIF32IMatmulUKernel();
63+
64+
// Returns the selected KleidiAI BF16 SBGEMM ukernel based on runtime CPU capabilities.
65+
const KaiBF16SBgemmKernel& GetKleidiAISBGemmUKernel();

onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,42 @@ MlasGemmBatch(
105105
MLAS_THREADPOOL* ThreadPool
106106
);
107107

108+
#if defined(__aarch64__) && defined(__linux__)
109+
size_t
110+
MLASCALL
111+
MlasSBGemmPackBSize(
112+
CBLAS_TRANSPOSE TransA,
113+
CBLAS_TRANSPOSE TransB,
114+
size_t N,
115+
size_t K
116+
);
117+
118+
bool
119+
MLASCALL
120+
MlasSBGemmPackB(
121+
CBLAS_TRANSPOSE TransA,
122+
CBLAS_TRANSPOSE TransB,
123+
size_t N,
124+
size_t K,
125+
const float* B,
126+
size_t ldb,
127+
void* PackedB
128+
);
129+
130+
bool
131+
MLASCALL
132+
MlasSBGemmBatch(
133+
CBLAS_TRANSPOSE TransA,
134+
CBLAS_TRANSPOSE TransB,
135+
size_t M,
136+
size_t N,
137+
size_t K,
138+
const MLAS_SBGEMM_DATA_PARAMS* Data,
139+
size_t BatchSize,
140+
MLAS_THREADPOOL* ThreadPool
141+
);
142+
#endif
143+
108144
size_t
109145
MLASCALL
110146
MlasDynamicQGemmPackBSize(

0 commit comments

Comments
 (0)