Skip to content

Commit bb7d2cd

Browse files
Adding support for BackendKernelSelectorConfig in SBGEMM
Signed-off-by: Patryk Kaiser <patryk.kaiser@arm.com>
1 parent 8c2008e commit bb7d2cd

File tree

6 files changed

+67
-45
lines changed

6 files changed

+67
-45
lines changed

onnxruntime/core/mlas/inc/mlas.h

Lines changed: 37 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2016,14 +2016,17 @@ struct MLAS_SBGEMM_DATA_PARAMS {
20162016
* Note: We only support uniform batching, so shapes and types of the
20172017
* input must be the same across all parameter blocks.
20182018
*
2019-
* @param[in] TransA Supplies the transpose operation for matrix A.
2020-
* @param[in] TransB Supplies the transpose operation for matrix B.
2021-
* @param[in] M row size of matrix A and C
2022-
* @param[in] N column size of matrix B and C
2023-
* @param[in] K column size of matrix A and row size of matrix B
2024-
* @param[in] BatchN number of batches
2025-
* @param[inout] DataParams An array (size BatchN) of parameter blocks
2019+
* @param[in] TransA Supplies the transpose operation for matrix A.
2020+
* @param[in] TransB Supplies the transpose operation for matrix B.
2021+
* @param[in] M row size of matrix A and C
2022+
* @param[in] N column size of matrix B and C
2023+
* @param[in] K column size of matrix A and row size of matrix B
2024+
* @param[in] BatchN number of batches
2025+
* @param[inout] DataParams An array (size BatchN) of parameter blocks
20262026
* @param[in] ThreadPool
2027+
* @param[in] BackendKernelSelectorConfig Supplies the backend kernel selector
2028+
configuration options, else nullptr if the
2029+
default configuration should be used.
20272030
* @return
20282031
*/
20292032
void MLASCALL
@@ -2035,41 +2038,49 @@ MlasSBGemmBatch(
20352038
const size_t K,
20362039
const size_t BatchN,
20372040
const MLAS_SBGEMM_DATA_PARAMS* DataParams,
2038-
MLAS_THREADPOOL* ThreadPool = nullptr
2041+
MLAS_THREADPOOL* ThreadPool,
2042+
const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig
20392043
);
20402044

20412045
/**
20422046
* @brief For bfloat16 precision GEMM, returns size of the
20432047
* packing buffer needed for right hand side
2044-
* @param[in] TransA Supplies the transpose operation for matrix A.
2045-
* @param[in] TransB Supplies the transpose operation for matrix B.
2046-
* @param[in] BIsfp32 Is matrix B datatype FP32
2047-
* @param[in] N Number of columns
2048-
* @param[in] K Number of rows
2049-
* @return size of the packing buffer,
2050-
* 0 if operation not supported
2048+
* @param[in] TransA Supplies the transpose operation for matrix A.
2049+
* @param[in] TransB Supplies the transpose operation for matrix B.
2050+
* @param[in] BIsfp32 Is matrix B datatype FP32
2051+
* @param[in] N Number of columns
2052+
* @param[in] K Number of rows
2053+
* @param[in] BackendKernelSelectorConfig Supplies the backend kernel selector
2054+
configuration options, else nullptr if the
2055+
default configuration should be used.
2056+
* @return size of the packing buffer,
2057+
* 0 if operation not supported
20512058
*/
20522059
size_t MLASCALL
20532060
MlasSBGemmPackBSize(
20542061
CBLAS_TRANSPOSE TransA,
20552062
CBLAS_TRANSPOSE TransB,
20562063
bool BIsfp32,
20572064
size_t N,
2058-
size_t K
2065+
size_t K,
2066+
const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig
20592067
);
20602068

20612069
/**
20622070
* @brief For bfloat16 precision GEMM, convert the float matrix B
20632071
* to bfloat16 precision and pack it into a packing buffer
20642072
*
2065-
* @param[in] TransA Supplies the transpose operation for matrix A.
2066-
* @param[in] TransB Supplies the transpose operation for matrix B.
2067-
* @param[in] BIsfp32 Is matrix B datatype FP32
2068-
* @param[in] N Number of columns
2069-
* @param[in] K Number of rows
2070-
* @param[in] B Address of matrix B
2071-
* @param[in] ldb leading dimension of input matrix B
2072-
* @param[out] PackedB Address of the packed matrix
2073+
* @param[in] TransA Supplies the transpose operation for matrix A.
2074+
* @param[in] TransB Supplies the transpose operation for matrix B.
2075+
* @param[in] BIsfp32 Is matrix B datatype FP32
2076+
* @param[in] N Number of columns
2077+
* @param[in] K Number of rows
2078+
* @param[in] B Address of matrix B
2079+
* @param[in] ldb leading dimension of input matrix B
2080+
* @param[out] PackedB Address of the packed matrix
2081+
* @param[in] BackendKernelSelectorConfig Supplies the backend kernel selector
2082+
configuration options, else nullptr if the
2083+
default configuration should be used.
20732084
*/
20742085
void MLASCALL
20752086
MlasSBGemmConvertPackB(
@@ -2080,7 +2091,8 @@ MlasSBGemmConvertPackB(
20802091
size_t K,
20812092
const float* B,
20822093
size_t ldb,
2083-
void* PackedB
2094+
void* PackedB,
2095+
const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig
20842096
);
20852097
#endif
20862098

onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ MlasConvPointwiseBf16KernelNeon(
9191
}
9292
}
9393

94-
MlasSBGemmBatch(OutputCount, BlockSize, BlockSize, idx, gemm_params, nullptr);
94+
MlasSBGemmBatch(CblasNoTrans, CblasNoTrans, OutputCount, BlockSize, BlockSize, idx, gemm_params, nullptr, nullptr);
9595

9696
if (ReluActivation) {
9797
const float32x4_t ZeroVector = MlasBroadcastFloat32x4(0.0f);

onnxruntime/core/mlas/lib/sbgemm.h

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -303,14 +303,17 @@ MlasSBGemmPackBSize(
303303
CBLAS_TRANSPOSE TransA,
304304
CBLAS_TRANSPOSE TransB,
305305
bool BIsfp32,
306-
size_t N,
307-
size_t K)
306+
size_t N,
307+
size_t K,
308+
const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig
309+
)
308310
{
309311
//
310312
// Compute the number of bytes required to hold the packed buffer.
311313
//
312314
#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) && !defined(MLAS_USE_ARM_NEON_NCHWC)
313-
if (GetMlasPlatform().MlasSBGemmPackBSizeOverride != nullptr &&
315+
if ((!BackendKernelSelectorConfig || BackendKernelSelectorConfig->use_kleidiai) &&
316+
GetMlasPlatform().MlasSBGemmPackBSizeOverride != nullptr &&
314317
TransA != CBLAS_TRANSPOSE::CblasTrans &&
315318
TransB != CBLAS_TRANSPOSE::CblasTrans &&
316319
BIsfp32) {
@@ -348,11 +351,13 @@ MlasSBGemmConvertPackB(
348351
size_t K,
349352
const float* B,
350353
size_t ldb,
351-
void* PackedB
354+
void* PackedB,
355+
const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig
352356
)
353357
{
354358
#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) && !defined(MLAS_USE_ARM_NEON_NCHWC)
355-
if (GetMlasPlatform().MlasSBGemmPackBOverride != nullptr &&
359+
if ((!BackendKernelSelectorConfig || BackendKernelSelectorConfig->use_kleidiai) &&
360+
GetMlasPlatform().MlasSBGemmPackBOverride != nullptr &&
356361
TransA != CBLAS_TRANSPOSE::CblasTrans &&
357362
TransB != CBLAS_TRANSPOSE::CblasTrans &&
358363
BIsfp32 &&
@@ -376,11 +381,13 @@ MlasSBGemmBatch(
376381
const size_t K,
377382
const size_t BatchN,
378383
const MLAS_SBGEMM_DATA_PARAMS* Data,
379-
MLAS_THREADPOOL* ThreadPool
384+
MLAS_THREADPOOL* ThreadPool,
385+
const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig
380386
)
381387
{
382388
#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) && !defined(MLAS_USE_ARM_NEON_NCHWC)
383-
if (GetMlasPlatform().MlasSBGemmBatchOverride != nullptr &&
389+
if ((!BackendKernelSelectorConfig || BackendKernelSelectorConfig->use_kleidiai) &&
390+
GetMlasPlatform().MlasSBGemmBatchOverride != nullptr &&
384391
TransA != CBLAS_TRANSPOSE::CblasTrans &&
385392
TransB != CBLAS_TRANSPOSE::CblasTrans &&
386393
Data->AIsfp32 &&

onnxruntime/core/providers/cpu/math/matmul.cc

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,8 @@ bool GemmPackBBfloat16(AllocatorPtr& alloc,
141141
bool trans_b,
142142
IAllocatorUniquePtr<void>& packed_b,
143143
size_t& packed_b_size,
144-
TensorShape& b_shape) {
144+
TensorShape& b_shape,
145+
const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* mlas_backend_kernel_selector_config) {
145146
// Only handle the common case of a 2D weight matrix. Additional matrices
146147
// could be handled by stacking the packed buffers.
147148
if (tensor_b.Shape().NumDimensions() != 2) {
@@ -157,7 +158,8 @@ bool GemmPackBBfloat16(AllocatorPtr& alloc,
157158
trans_b ? CBLAS_TRANSPOSE::CblasTrans : CBLAS_TRANSPOSE::CblasNoTrans,
158159
true,
159160
N,
160-
K);
161+
K,
162+
mlas_backend_kernel_selector_config);
161163
if (packed_b_size == 0) {
162164
return false;
163165
}
@@ -176,7 +178,8 @@ bool GemmPackBBfloat16(AllocatorPtr& alloc,
176178
K,
177179
tensor_b.Data<float>(),
178180
trans_b ? K : N,
179-
packed_b_data);
181+
packed_b_data,
182+
mlas_backend_kernel_selector_config);
180183
return true;
181184
}
182185
#endif
@@ -200,7 +203,7 @@ Status MatMul<float>::PrePack(const Tensor& tensor, int input_idx, /*out*/ Alloc
200203
}
201204

202205
if (use_fastmath_mode_ && (trans_b_attr_ == 0) && ((dim1 * dim2) >= kFastMathModeKernelsizeThreshold)) {
203-
is_packed = GemmPackBBfloat16(alloc, tensor, trans_a_attr_ != 0, trans_b_attr_ != 0, packed_b_, packed_b_size, b_shape_);
206+
is_packed = GemmPackBBfloat16(alloc, tensor, trans_a_attr_ != 0, trans_b_attr_ != 0, packed_b_, packed_b_size, b_shape_, &mlas_backend_kernel_selector_config_);
204207
} else
205208
#endif
206209
{
@@ -284,7 +287,7 @@ Status MatMul<float>::Compute(OpKernelContext* ctx) const {
284287
data[i].BIsPacked = static_cast<bool>(packed_b_);
285288
}
286289
MlasSBGemmBatch(trans_a ? CblasTrans : CblasNoTrans, trans_b ? CblasTrans : CblasNoTrans,
287-
M, N, K, max_len, data.data(), thread_pool);
290+
M, N, K, max_len, data.data(), thread_pool, &mlas_backend_kernel_selector_config_);
288291
} else
289292
#endif
290293
{

onnxruntime/test/mlas/unittest/test_sbgemm.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,13 +141,13 @@ static size_t SBGemmRegistLongExecute() {
141141
size_t count = 0;
142142

143143
count += MlasLongExecuteTests<MlasSBGemmTest<float, float, false, false>>::RegisterLongExecute();
144-
if (MlasSBGemmPackBSize(CblasNoTrans, CblasNoTrans, true, 128, 128) > 0) {
144+
if (MlasSBGemmPackBSize(CblasNoTrans, CblasNoTrans, true, 128, 128, nullptr) > 0) {
145145
count += MlasLongExecuteTests<MlasSBGemmTest<float, float, true, false>>::RegisterLongExecute();
146146
}
147147

148148
if (GetMlasThreadPool() != nullptr) {
149149
count += MlasLongExecuteTests<MlasSBGemmTest<float, float, false, true>>::RegisterLongExecute();
150-
if (MlasSBGemmPackBSize(CblasNoTrans, CblasNoTrans, true, 128, 128) > 0) {
150+
if (MlasSBGemmPackBSize(CblasNoTrans, CblasNoTrans, true, 128, 128, nullptr) > 0) {
151151
count += MlasLongExecuteTests<MlasSBGemmTest<float, float, true, true>>::RegisterLongExecute();
152152
}
153153
}
@@ -160,15 +160,15 @@ static size_t SBGemmRegistShortExecute() {
160160

161161
count += SBGemmShortExecuteTest<float, float, false, false>::RegisterShortExecuteTests();
162162
count += SBGemmAccumulateExecuteTest<float, float, false, false>::RegisterAccumulateTests();
163-
if (MlasSBGemmPackBSize(CblasNoTrans, CblasNoTrans, true, 128, 128) > 0) {
163+
if (MlasSBGemmPackBSize(CblasNoTrans, CblasNoTrans, true, 128, 128, nullptr) > 0) {
164164
count += SBGemmShortExecuteTest<float, float, true, false>::RegisterShortExecuteTests();
165165
count += SBGemmAccumulateExecuteTest<float, float, true, false>::RegisterAccumulateTests();
166166
}
167167

168168
if (GetMlasThreadPool() != nullptr) {
169169
count += SBGemmShortExecuteTest<float, float, false, true>::RegisterShortExecuteTests();
170170
count += SBGemmAccumulateExecuteTest<float, float, false, true>::RegisterAccumulateTests();
171-
if (MlasSBGemmPackBSize(CblasNoTrans, CblasNoTrans, true, 128, 128) > 0) {
171+
if (MlasSBGemmPackBSize(CblasNoTrans, CblasNoTrans, true, 128, 128, nullptr) > 0) {
172172
count += SBGemmShortExecuteTest<float, float, true, true>::RegisterShortExecuteTests();
173173
count += SBGemmAccumulateExecuteTest<float, float, true, true>::RegisterAccumulateTests();
174174
}

onnxruntime/test/mlas/unittest/test_sbgemm.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,13 @@ class MlasSBGemmTest : public MlasTestBase {
6262

6363
void* PackB(CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, size_t N, size_t K, const BType* B, size_t ldb) {
6464
const bool BIsfp32 = std::is_same<BType, float>::value;
65-
size_t PackedBSize = MlasSBGemmPackBSize(TransA, TransB, BIsfp32, N, K);
65+
size_t PackedBSize = MlasSBGemmPackBSize(TransA, TransB, BIsfp32, N, K, nullptr);
6666
if (PackedBSize == 0) {
6767
return nullptr;
6868
}
6969
void* PackedB = BufferBPacked.GetBuffer(PackedBSize);
7070
if (std::is_same<BType, float>::value) {
71-
MlasSBGemmConvertPackB(TransA, TransB, true, N, K, (const float*)B, ldb, PackedB);
71+
MlasSBGemmConvertPackB(TransA, TransB, true, N, K, (const float*)B, ldb, PackedB, nullptr);
7272
} else {
7373
}
7474
return PackedB;
@@ -118,7 +118,7 @@ class MlasSBGemmTest : public MlasTestBase {
118118
}
119119
}
120120

121-
MlasSBGemmBatch(TransA, TransB, M, N, K, BatchSize, GemmParameters.data(), threadpool_);
121+
MlasSBGemmBatch(TransA, TransB, M, N, K, BatchSize, GemmParameters.data(), threadpool_, nullptr);
122122
}
123123

124124
void ReferenceSgemm(size_t M,

0 commit comments

Comments
 (0)