Skip to content

Commit e64c694

Browse files
Adding QGemm pointers and ukernel interface
1 parent a98c912 commit e64c694

File tree

8 files changed

+145
-107
lines changed

8 files changed

+145
-107
lines changed

onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,12 @@
1919
#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.h"
2020
#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h"
2121
#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h"
22+
23+
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa.h"
24+
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h"
2225
#if defined(ENABLE_QMX_KERNELS)
2326
#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_qmx_mopa.h"
27+
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_qmx_mopa.h"
2428
#endif // ENABLE_QMX_KERNELS
2529

2630
const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod =
@@ -125,6 +129,32 @@ const kai_matmul_clamp_f32_f32p_f32p_ukernel sgemm_gemm_sme2 =
125129
kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa,
126130
kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa};
127131

132+
const kai_matmul_clamp_f32_qai8dxp_qsi8cxp_ukernel qgemm_gemm_sme =
133+
{kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa,
134+
kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa,
135+
kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa,
136+
kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa,
137+
kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa,
138+
kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa,
139+
kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa,
140+
kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa,
141+
kai_get_dst_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa,
142+
kai_get_dst_size_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa,
143+
kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa};
144+
145+
const kai_matmul_clamp_f32_qai8dxp_qsi8cxp_ukernel qgemm_gemm_sme2 =
146+
{kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
147+
kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
148+
kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
149+
kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
150+
kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
151+
kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
152+
kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
153+
kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
154+
kai_get_dst_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
155+
kai_get_dst_size_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
156+
kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa};
157+
128158
#if defined(ENABLE_QMX_KERNELS)
129159
const kai_matmul_clamp_f32_f32p_f32p_ukernel sgemm_gemm_qmx =
130160
{kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_qmx_mopa,
@@ -138,6 +168,19 @@ const kai_matmul_clamp_f32_f32p_f32p_ukernel sgemm_gemm_qmx =
138168
kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_qmx_mopa,
139169
kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_qmx_mopa,
140170
kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_qmx_mopa};
171+
172+
const kai_matmul_clamp_f32_qai8dxp_qsi8cxp_ukernel qgemm_gemm_qmx =
173+
{kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_qmx_mopa,
174+
kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_qmx_mopa,
175+
kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_qmx_mopa,
176+
kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_qmx_mopa,
177+
kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_qmx_mopa,
178+
kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_qmx_mopa,
179+
kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_qmx_mopa,
180+
kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_qmx_mopa,
181+
kai_get_dst_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_qmx_mopa,
182+
kai_get_dst_size_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_qmx_mopa,
183+
kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_qmx_mopa};
141184
#endif // ENABLE_QMX_KERNELS
142185

143186
const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemmUKernel() {
@@ -181,3 +224,21 @@ const kai_matmul_clamp_f32_f32_f32p_ukernel& GetKleidiAISGemvUKernel() {
181224
return sgemm_gemv_sme;
182225
}
183226
}
227+
228+
const kai_matmul_clamp_f32_qai8dxp_qsi8cxp_ukernel& GetKleidiAIQGemmUKernel() {
229+
if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2()) {
230+
return qgemm_gemm_sme2;
231+
} else {
232+
#if defined(ENABLE_QMX_KERNELS)
233+
if (ArmKleidiAI::vendor_name.compare("Qualcomm") == 0)
234+
{
235+
KLEIDIAI_KERNEL_LOG("QGEMM: Using QMX Kernel");
236+
return qgemm_gemm_qmx;
237+
} else {
238+
return qgemm_gemm_sme;
239+
}
240+
#else
241+
return qgemm_gemm_sme;
242+
#endif // ENABLE_QMX_KERNELS
243+
}
244+
}

onnxruntime/core/mlas/lib/kai_ukernel_interface.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,12 @@
1212

1313
#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p_interface.h"
1414

15+
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp_qsi8cxp_interface.h"
16+
1517
const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemmUKernel();
1618
const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemvUKernel();
1719

1820
const kai_matmul_clamp_f32_f32p_f32p_ukernel& GetKleidiAISGemmUKernel();
1921
const kai_matmul_clamp_f32_f32_f32p_ukernel& GetKleidiAISGemvUKernel();
22+
23+
const kai_matmul_clamp_f32_qai8dxp_qsi8cxp_ukernel& GetKleidiAIQGemmUKernel();

onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,14 +107,14 @@ MlasGemmBatch(
107107

108108
size_t
109109
MLASCALL
110-
MlasDynamicQgemmPackBSize(
110+
MlasDynamicQGemmPackBSize(
111111
size_t N,
112112
size_t K
113113
);
114114

115115
void
116116
MLASCALL
117-
MlasDynamicQgemmPackB(
117+
MlasDynamicQGemmPackB(
118118
size_t N,
119119
size_t K,
120120
const int8_t* B,

onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp

Lines changed: 19 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,7 @@
1010
#include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h"
1111
#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon.h"
1212

13-
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h"
14-
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa.h"
15-
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.h"
13+
#include "kai_ukernel_interface.h"
1614
#if defined(ENABLE_QMX_KERNELS)
1715
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_qmx_mopa.h"
1816
#endif // ENABLE_QMX_KERNELS
@@ -26,11 +24,13 @@ struct KaiTlsBuffersQgemm {
2624
};
2725
static thread_local KaiTlsBuffersQgemm g_kai_tls_qgemm;
2826

27+
const kai_matmul_clamp_f32_qai8dxp_qsi8cxp_ukernel qgemm_gemm = GetKleidiAIQGemmUKernel();
28+
2929
// Matmul with float output of dynamic-quantized A and symmetric-quantized B.
3030

3131
size_t
3232
MLASCALL
33-
ArmKleidiAI::MlasDynamicQgemmPackBSize(
33+
ArmKleidiAI::MlasDynamicQGemmPackBSize(
3434
size_t N,
3535
size_t K
3636
) {
@@ -39,10 +39,9 @@ ArmKleidiAI::MlasDynamicQgemmPackBSize(
3939
return 0;
4040
}
4141

42-
// Default to sme2_mopa, but this may not always be the most optimal kernel variant to use.
43-
auto nr = kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa();
44-
auto kr = kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa();
45-
auto sr = kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa();
42+
auto nr = qgemm_gemm.get_nr();
43+
auto kr = qgemm_gemm.get_kr();
44+
auto sr = qgemm_gemm.get_sr();
4645

4746
// Regardless of kernel variant, use the NEON packing variant.
4847
KLEIDIAI_KERNEL_LOG("kai_run_rhs_pack_kxn_qsi8cxp_qsi8cx_neon Groups=1"
@@ -52,7 +51,7 @@ ArmKleidiAI::MlasDynamicQgemmPackBSize(
5251

5352
void
5453
MLASCALL
55-
ArmKleidiAI::MlasDynamicQgemmPackB(
54+
ArmKleidiAI::MlasDynamicQGemmPackB(
5655
size_t N,
5756
size_t K,
5857
const int8_t* B,
@@ -65,10 +64,9 @@ ArmKleidiAI::MlasDynamicQgemmPackB(
6564
return;
6665
}
6766

68-
// Default to sme2_mopa, but this may not always be the most optimal kernel variant to use.
69-
auto nr = kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa();
70-
auto kr = kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa();
71-
auto sr = kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa();
67+
auto nr = qgemm_gemm.get_nr();
68+
auto kr = qgemm_gemm.get_kr();
69+
auto sr = qgemm_gemm.get_sr();
7270

7371
// y - float output
7472
// scale_factor_lhs - lhs scaling factor
@@ -105,17 +103,12 @@ ArmKleidiAI::MlasDynamicQGemmBatch(
105103
MLAS_THREADPOOL* ThreadPool
106104
) {
107105

108-
const size_t mr = UseSME2 ? kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa()
109-
: kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa();
110-
const size_t kr = UseSME2 ? kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa()
111-
: kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa();
112-
const size_t sr = UseSME2 ? kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa()
113-
: kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa();
106+
const size_t mr = qgemm_gemm.get_mr();
107+
const size_t kr = qgemm_gemm.get_kr();
108+
const size_t sr = qgemm_gemm.get_sr();
114109

115-
size_t m_step = UseSME2 ? kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa()
116-
: kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa();
117-
size_t n_step = UseSME2 ? kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa()
118-
: kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa();
110+
size_t m_step = qgemm_gemm.get_m_step();
111+
size_t n_step = qgemm_gemm.get_n_step();
119112

120113
if (BatchSize == 0 || Shape.M == 0 || Shape.N == 0 || Shape.K == 0) {
121114
return;
@@ -216,17 +209,13 @@ ArmKleidiAI::MlasDynamicQGemmBatch(
216209
ptrdiff_t NIdx = (tid % (dim[1] * dim[2])) % dim[2];
217210

218211
// Get rhs tile, B
219-
const size_t rhs_packed_offset =
220-
UseSME2 ? kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(NIdx * n_step, Shape.K)
221-
: kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa(NIdx * n_step, Shape.K);
212+
const size_t rhs_packed_offset = qgemm_gemm.get_rhs_packed_offset(NIdx * n_step, Shape.K);
222213

223214
const std::byte* B_base = reinterpret_cast<const std::byte*>(DataParams[BIdx].PackedB);
224215
auto BTile = reinterpret_cast<const void*>(B_base + rhs_packed_offset);
225216

226217
// Get lhs tile, A
227-
const size_t lhs_packed_offset =
228-
UseSME2 ? kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(MIdx * m_step, Shape.K)
229-
: kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa(MIdx * m_step, Shape.K);
218+
const size_t lhs_packed_offset =qgemm_gemm.get_lhs_packed_offset(MIdx * m_step, Shape.K);
230219

231220
const std::byte* A_base = tls_lhs_base[BIdx]; // LhsPackedData + LhsPackedStride * BIdx; OR DataParams[batch_idx].Workspace;
232221
auto ATile = reinterpret_cast<const std::byte*>(A_base + lhs_packed_offset);
@@ -240,46 +229,12 @@ ArmKleidiAI::MlasDynamicQGemmBatch(
240229
NIdx * n_step * sizeof(float)
241230
);
242231

243-
if (UseSME2) {
244-
kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(
232+
qgemm_gemm.run_matmul(
245233
TileSizeM, TileSizeN, Shape.K, ATile, BTile,
246234
dst_tile,
247235
DataParams[BIdx].ldc * sizeof(float),
248236
sizeof(float),
249237
-std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
250238
);
251-
}
252-
else {
253-
#if defined(ENABLE_QMX_KERNELS)
254-
if(ArmKleidiAI::vendor_name.compare("Qualcomm") == 0)
255-
{
256-
kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_qmx_mopa(
257-
TileSizeM, TileSizeN, Shape.K, ATile, BTile,
258-
dst_tile,
259-
DataParams[BIdx].ldc * sizeof(float),
260-
sizeof(float),
261-
-std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
262-
);
263-
}
264-
else
265-
{
266-
kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa(
267-
TileSizeM, TileSizeN, Shape.K, ATile, BTile,
268-
dst_tile,
269-
DataParams[BIdx].ldc * sizeof(float),
270-
sizeof(float),
271-
-std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
272-
);
273-
}
274-
#else
275-
kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa(
276-
TileSizeM, TileSizeN, Shape.K, ATile, BTile,
277-
dst_tile,
278-
DataParams[BIdx].ldc * sizeof(float),
279-
sizeof(float),
280-
-std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
281-
);
282-
#endif // ENABLE_QMX_KERNELS
283-
}
284239
});
285240
}

onnxruntime/core/mlas/lib/mlasi.h

Lines changed: 38 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -851,17 +851,9 @@ bool
851851
MLAS_THREADPOOL* ThreadPool
852852
);
853853

854-
typedef void (MLASCALL MLAS_GEMM_BATCH)(
855-
CBLAS_TRANSPOSE TransA,
856-
CBLAS_TRANSPOSE TransB,
857-
size_t M,
858-
size_t N,
859-
size_t K,
860-
const MLAS_SGEMM_DATA_PARAMS* Data,
861-
size_t BatchSize,
862-
MLAS_THREADPOOL* ThreadPool);
863-
864-
typedef bool (MLASCALL MLAS_GEMM_BATCH_OVERRIDE)(
854+
typedef
855+
bool
856+
(MLASCALL MLAS_SGEMM_BATCH_OVERRIDE)(
865857
CBLAS_TRANSPOSE TransA,
866858
CBLAS_TRANSPOSE TransB,
867859
size_t M,
@@ -871,19 +863,17 @@ typedef bool (MLASCALL MLAS_GEMM_BATCH_OVERRIDE)(
871863
size_t BatchSize,
872864
MLAS_THREADPOOL* ThreadPool);
873865

874-
typedef size_t (MLASCALL MLAS_GEMM_PACK_B_SIZE)(
875-
CBLAS_TRANSPOSE TransA,
876-
CBLAS_TRANSPOSE TransB,
877-
size_t N,
878-
size_t K);
879-
880-
typedef size_t (MLASCALL MLAS_GEMM_PACK_B_SIZE_OVERRIDE)(
866+
typedef
867+
size_t
868+
(MLASCALL MLAS_SGEMM_PACK_B_SIZE_OVERRIDE)(
881869
CBLAS_TRANSPOSE TransA,
882870
CBLAS_TRANSPOSE TransB,
883871
size_t N,
884872
size_t K);
885873

886-
typedef void (MLASCALL MLAS_GEMM_PACK_B)(
874+
typedef
875+
bool
876+
(MLASCALL MLAS_SGEMM_PACK_B_OVERRIDE)(
887877
CBLAS_TRANSPOSE TransA,
888878
CBLAS_TRANSPOSE TransB,
889879
size_t N,
@@ -892,13 +882,28 @@ typedef void (MLASCALL MLAS_GEMM_PACK_B)(
892882
size_t ldb,
893883
void* PackedB);
894884

895-
typedef bool (MLASCALL MLAS_GEMM_PACK_B_OVERRIDE)(
896-
CBLAS_TRANSPOSE TransA,
897-
CBLAS_TRANSPOSE TransB,
885+
typedef
886+
void
887+
(MLASCALL MLAS_DYNAMIC_QGEMM_BATCH_OVERRIDE)(
888+
const MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS& Shape,
889+
const MLAS_GEMM_DYN_QUANT_DATA_PARAMS* DataParams,
890+
const size_t BatchN,
891+
MLAS_THREADPOOL* ThreadPool);
892+
893+
typedef
894+
size_t
895+
(MLASCALL MLAS_DYNAMIC_QGEMM_PACK_B_SIZE_OVERRIDE)(
896+
size_t N,
897+
size_t K);
898+
899+
typedef
900+
void
901+
(MLASCALL MLAS_DYNAMIC_QGEMM_PACK_B_OVERRIDE)(
898902
size_t N,
899903
size_t K,
900-
const float* B,
901-
size_t ldb,
904+
const int8_t* B,
905+
const float* Scales,
906+
const float* Bias,
902907
void* PackedB);
903908

904909
extern "C" {
@@ -1348,10 +1353,15 @@ struct MLAS_PLATFORM {
13481353
bool Avx512Supported_ = false;
13491354
bool ArmNeonIsQuantActivationsUnsigned = false;
13501355

1351-
// Mlas overrides initialisation
1352-
MLAS_GEMM_BATCH_OVERRIDE* MlasGemmBatchOverride = nullptr;
1353-
MLAS_GEMM_PACK_B_SIZE_OVERRIDE* MlasGemmPackBSizeOverride = nullptr;
1354-
MLAS_GEMM_PACK_B_OVERRIDE* MlasGemmPackBOverride = nullptr;
1356+
// MLAS SGemm overrides
1357+
MLAS_SGEMM_BATCH_OVERRIDE* MlasSGemmBatchOverride = nullptr;
1358+
MLAS_SGEMM_PACK_B_SIZE_OVERRIDE* MlasSGemmPackBSizeOverride = nullptr;
1359+
MLAS_SGEMM_PACK_B_OVERRIDE* MlasSGemmPackBOverride = nullptr;
1360+
// MLAS Dynamic QGemm overrides
1361+
MLAS_DYNAMIC_QGEMM_BATCH_OVERRIDE* MlasDynamicQGemmBatchOverride = nullptr;
1362+
MLAS_DYNAMIC_QGEMM_PACK_B_SIZE_OVERRIDE* MlasDynamicQGemmPackBSizeOverride = nullptr;
1363+
MLAS_DYNAMIC_QGEMM_PACK_B_OVERRIDE* MlasDynamicQGemmPackBOverride = nullptr;
1364+
// MLAS Conv overrides
13551365
MLAS_CONV_PREPARE_FLOAT_OVERRIDE* MlasConvPrepareOverride = nullptr;
13561366
MLAS_CONV_FLOAT_OVERRIDE* MlasConvOverride = nullptr;
13571367

onnxruntime/core/mlas/lib/platform.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -618,9 +618,12 @@ Return Value:
618618

619619
#if defined(USE_KLEIDIAI)
620620
if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){
621-
this->MlasGemmBatchOverride = ArmKleidiAI::MlasGemmBatch;
622-
this->MlasGemmPackBSizeOverride = ArmKleidiAI::MlasGemmPackBSize;
623-
this->MlasGemmPackBOverride = ArmKleidiAI::MlasGemmPackB;
621+
this->MlasSGemmBatchOverride = ArmKleidiAI::MlasGemmBatch;
622+
this->MlasSGemmPackBSizeOverride = ArmKleidiAI::MlasGemmPackBSize;
623+
this->MlasSGemmPackBOverride = ArmKleidiAI::MlasGemmPackB;
624+
this->MlasDynamicQGemmBatchOverride = ArmKleidiAI::MlasDynamicQGemmBatch;
625+
this->MlasDynamicQGemmPackBSizeOverride = ArmKleidiAI::MlasDynamicQGemmPackBSize;
626+
this->MlasDynamicQGemmPackBOverride = ArmKleidiAI::MlasDynamicQGemmPackB;
624627
this->MlasConvPrepareOverride = ArmKleidiAI::MlasConvPrepare;
625628
this->MlasConvOverride = ArmKleidiAI::MlasConv;
626629
}

0 commit comments

Comments
 (0)