Skip to content

Commit 120eaa1

Browse files
Integrate KleidiAI BF16 SME2 Kernel Through Mlas SBGEMM Path
Signed-off-by: Patryk Kaiser <patryk.kaiser@arm.com>
1 parent 4612613 commit 120eaa1

File tree

12 files changed

+809
-23
lines changed

12 files changed

+809
-23
lines changed

cmake/onnxruntime_mlas.cmake

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,7 @@ function(setup_kleidiai)
284284
target_sources(onnxruntime_mlas PRIVATE
285285
${MLAS_SRC_DIR}/kai_ukernel_interface.cpp
286286
${MLAS_SRC_DIR}/kleidiai/sgemm_kleidiai.cpp
287+
${MLAS_SRC_DIR}/kleidiai/sbgemm_kleidiai.cpp
287288
${MLAS_SRC_DIR}/kleidiai/convolve_kleidiai.cpp
288289
${MLAS_SRC_DIR}/kleidiai/qgemm_kleidiai.cpp
289290
)
@@ -550,6 +551,23 @@ else()
550551
set_source_files_properties(${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
551552
endif()
552553

554+
#
555+
# Mac arm64 (M-series) supports BF16 instructions, and ORT's SBGEMM
556+
# entrypoints (MlasSBGemm*) are provided by the AArch64 SBGEMM sources.
557+
# Add minimal changes required to enable sbgemm on this platform
558+
#
559+
if (APPLE AND CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
560+
set(mlas_platform_srcs
561+
${mlas_platform_srcs}
562+
${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S
563+
${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp
564+
${MLAS_SRC_DIR}/sbconv_kernel_neon.cpp
565+
)
566+
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
567+
set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
568+
set_source_files_properties(${MLAS_SRC_DIR}/sbconv_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
569+
endif()
570+
553571
if(ONNXRUNTIME_MLAS_MULTI_ARCH)
554572
onnxruntime_add_static_library(onnxruntime_mlas_arm64 ${mlas_platform_srcs})
555573
set_target_properties(onnxruntime_mlas_arm64 PROPERTIES OSX_ARCHITECTURES "arm64")

onnxruntime/core/mlas/inc/mlas.h

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2006,6 +2006,7 @@ struct MLAS_SBGEMM_DATA_PARAMS {
20062006
bool AIsfp32 = false; /**< matrix A is fp32, needs to be converted to bf16*/
20072007
bool BIsfp32 = false; /**< matrix B is fp32, needs to be converted to bf16*/
20082008
bool ZeroMode = true; /**< true: C = A*B, false: C += A*B */
2009+
bool BIsPacked = false; /**< Whether B is pre-packed */
20092010
};
20102011

20112012
/**
@@ -2015,6 +2016,8 @@ struct MLAS_SBGEMM_DATA_PARAMS {
20152016
* Note: We only support uniform batching, so shapes and types of the
20162017
* input must be same across all parameter blocks.
20172018
*
2019+
* @param[in] TransA Supplies the transpose operation for matrix A.
2020+
* @param[in] TransB Supplies the transpose operation for matrix B.
20182021
* @param[in] M row size of matrix A and C
20192022
* @param[in] N column size of matrix B and C
20202023
* @param[in] K column size of matrix A and row size of matrix B
@@ -2024,31 +2027,61 @@ struct MLAS_SBGEMM_DATA_PARAMS {
20242027
* @return
20252028
*/
20262029
void MLASCALL
2027-
MlasSBGemmBatch(const size_t M, const size_t N, const size_t K, const size_t BatchN, const MLAS_SBGEMM_DATA_PARAMS* DataParams, MLAS_THREADPOOL* ThreadPool);
2030+
MlasSBGemmBatch(
2031+
CBLAS_TRANSPOSE TransA,
2032+
CBLAS_TRANSPOSE TransB,
2033+
const size_t M,
2034+
const size_t N,
2035+
const size_t K,
2036+
const size_t BatchN,
2037+
const MLAS_SBGEMM_DATA_PARAMS* DataParams,
2038+
MLAS_THREADPOOL* ThreadPool = nullptr
2039+
);
20282040

20292041
/**
20302042
* @brief For bfloat16 precision GEMM, returns size of the
20312043
* packing buffer needed for right hand side
2032-
* @param[in] N Number of columns
2033-
* @param[in] K Number of rows
2044+
* @param[in] TransA Supplies the transpose operation for matrix A.
2045+
* @param[in] TransB Supplies the transpose operation for matrix B.
2046+
* @param[in] BIsfp32 Is matrix B datatype FP32
2047+
* @param[in] N Number of columns
2048+
* @param[in] K Number of rows
20342049
* @return size of the packing buffer,
20352050
* 0 if operation not supported
20362051
*/
20372052
size_t MLASCALL
2038-
MlasSBGemmPackBSize(size_t N, size_t K);
2053+
MlasSBGemmPackBSize(
2054+
CBLAS_TRANSPOSE TransA,
2055+
CBLAS_TRANSPOSE TransB,
2056+
bool BIsfp32,
2057+
size_t N,
2058+
size_t K
2059+
);
20392060

20402061
/**
20412062
* @brief For bfloat16 precision GEMM, convert the float matrix B
20422063
* to bfloat16 precision and pack it into a packing buffer
20432064
*
2065+
* @param[in] TransA Supplies the transpose operation for matrix A.
2066+
* @param[in] TransB Supplies the transpose operation for matrix B.
2067+
* @param[in] BIsfp32 Is matrix B datatype FP32
20442068
* @param[in] N Number of columns
20452069
* @param[in] K Number of rows
20462070
* @param[in] B Address of matrix B
20472071
* @param[in] ldb leading dimension of input matrix B
20482072
* @param[out] PackedB Address of the packed matrix
20492073
*/
20502074
void MLASCALL
2051-
MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* PackedB);
2075+
MlasSBGemmConvertPackB(
2076+
CBLAS_TRANSPOSE TransA,
2077+
CBLAS_TRANSPOSE TransB,
2078+
bool BIsfp32,
2079+
size_t N,
2080+
size_t K,
2081+
const float* B,
2082+
size_t ldb,
2083+
void* PackedB
2084+
);
20522085
#endif
20532086

20542087
/**

onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,11 @@
3030
#include "kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h"
3131

3232
// SME2 kernels
33-
// GEMM/QGEMM
33+
// GEMM/QGEMM/SBGEMM
3434
#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h"
3535
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h"
36+
#include "kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h"
37+
3638
// GEMV
3739
#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.h"
3840
#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.h"
@@ -227,6 +229,9 @@ const KaiF32IMatmulKernel imatmul_conv_sme =
227229
const KaiF32IMatmulKernel imatmul_conv_sme2 =
228230
KAI_WRAP_UKERNEL_RUN_IMATMUL_PACKED_7(imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa);
229231

232+
const KaiBF16SBgemmKernel sbgemm_gemm_sme2 =
233+
KAI_WRAP_UKERNEL_RUN_MATMUL_11(matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa);
234+
230235
#if defined(ENABLE_QMX_KERNELS)
231236
const KaiF32IMatmulKernel imatmul_conv_qmx =
232237
KAI_WRAP_UKERNEL_RUN_IMATMUL_PACKED_7(imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_qmx_mopa);
@@ -362,3 +367,8 @@ const KaiDynamicQGemmKernel& GetKleidiAIQGemmUKernel() {
362367
#endif // ENABLE_QMX_KERNELS
363368
}
364369
}
370+
371+
const KaiBF16SBgemmKernel& GetKleidiAISBGemmUKernel() {
372+
// Currently only SME2 variant exists for bfloat16/SBGEMM kernel
373+
return sbgemm_gemm_sme2;
374+
}

onnxruntime/core/mlas/lib/kai_ukernel_interface.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p_interface.h"
1414

15+
#include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h"
16+
1517
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp_qsi8cxp_interface.h"
1618

1719
#include "kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p_f32p_interface.h"
@@ -39,6 +41,8 @@ using KaiDynamicQGemmKernel = KaiMatmulKernel<kai_matmul_clamp_f32_qai8dxp_qsi8c
3941
// Wrapper for FP32 IMATMUL kernels used by the KleidiAI convolution implementation.
4042
using KaiF32IMatmulKernel = KaiMatmulKernel<kai_imatmul_clamp_f32_f32p_f32p_ukernel>;
4143

44+
using KaiBF16SBgemmKernel = KaiMatmulKernel<kai_matmul_clamp_f32_bf16p_bf16p_ukernel>;
45+
4246
// Returns the selected Qnbit GEMM ukernel based on runtime CPU capabilities.
4347
const KaiQnbitGemmKernel& GetKleidiAIGemmUKernel();
4448

@@ -56,3 +60,6 @@ const KaiF32SgemvKernel& GetKleidiAISGemvUKernel();
5660

5761
// Returns the selected FP32 IMATMUL ukernel used by the KleidiAI convolution implementation.
5862
const KaiF32IMatmulKernel& GetKleidiAIF32IMatmulUKernel();
63+
64+
// Returns the BF16 SBGEMM ukernel selected by KleidiAI based on runtime CPU capabilities.
65+
const KaiBF16SBgemmKernel& GetKleidiAISBGemmUKernel();

onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,42 @@ MlasGemmBatch(
105105
MLAS_THREADPOOL* ThreadPool
106106
);
107107

108+
#if defined(__aarch64__) && defined(__linux__)
109+
size_t
110+
MLASCALL
111+
MlasSBGemmPackBSize(
112+
CBLAS_TRANSPOSE TransA,
113+
CBLAS_TRANSPOSE TransB,
114+
size_t N,
115+
size_t K
116+
);
117+
118+
bool
119+
MLASCALL
120+
MlasSBGemmPackB(
121+
CBLAS_TRANSPOSE TransA,
122+
CBLAS_TRANSPOSE TransB,
123+
size_t N,
124+
size_t K,
125+
const float* B,
126+
size_t ldb,
127+
void* PackedB
128+
);
129+
130+
bool
131+
MLASCALL
132+
MlasSBGemmBatch(
133+
CBLAS_TRANSPOSE TransA,
134+
CBLAS_TRANSPOSE TransB,
135+
size_t M,
136+
size_t N,
137+
size_t K,
138+
const MLAS_SBGEMM_DATA_PARAMS* Data,
139+
size_t BatchSize,
140+
MLAS_THREADPOOL* ThreadPool
141+
);
142+
#endif
143+
108144
size_t
109145
MLASCALL
110146
MlasDynamicQGemmPackBSize(

0 commit comments

Comments
 (0)