 #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h"
 #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon.h"
 
-#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h"
-#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa.h"
-#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.h"
+#include "kai_ukernel_interface.h"
 #if defined(ENABLE_QMX_KERNELS)
 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_qmx_mopa.h"
 #endif // ENABLE_QMX_KERNELS
@@ -26,11 +24,13 @@ struct KaiTlsBuffersQgemm {
 };
 static thread_local KaiTlsBuffersQgemm g_kai_tls_qgemm;
 
+const kai_matmul_clamp_f32_qai8dxp_qsi8cxp_ukernel qgemm_gemm = GetKleidiAIQGemmUKernel();
+
 // Matmul with float output of dynamic-quantized A and symmetric-quantized B.
 
 size_t
 MLASCALL
-ArmKleidiAI::MlasDynamicQgemmPackBSize(
+ArmKleidiAI::MlasDynamicQGemmPackBSize(
     size_t N,
     size_t K
 ) {
@@ -39,10 +39,9 @@ ArmKleidiAI::MlasDynamicQgemmPackBSize(
         return 0;
     }
 
-    // Default to sme2_mopa, but this may not always be the most optimal kernel variant to use.
-    auto nr = kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa();
-    auto kr = kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa();
-    auto sr = kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa();
+    auto nr = qgemm_gemm.get_nr();
+    auto kr = qgemm_gemm.get_kr();
+    auto sr = qgemm_gemm.get_sr();
 
     // Regardless of kernel variant, use the NEON packing variant.
     KLEIDIAI_KERNEL_LOG("kai_run_rhs_pack_kxn_qsi8cxp_qsi8cx_neon Groups=1"
@@ -52,7 +51,7 @@ ArmKleidiAI::MlasDynamicQgemmPackBSize(
 
 void
 MLASCALL
-ArmKleidiAI::MlasDynamicQgemmPackB(
+ArmKleidiAI::MlasDynamicQGemmPackB(
     size_t N,
     size_t K,
     const int8_t* B,
@@ -65,10 +64,9 @@ ArmKleidiAI::MlasDynamicQgemmPackB(
         return;
     }
 
-    // Default to sme2_mopa, but this may not always be the most optimal kernel variant to use.
-    auto nr = kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa();
-    auto kr = kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa();
-    auto sr = kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa();
+    auto nr = qgemm_gemm.get_nr();
+    auto kr = qgemm_gemm.get_kr();
+    auto sr = qgemm_gemm.get_sr();
 
     // y - float output
     // scale_factor_lhs - lhs scaling factor
@@ -105,17 +103,12 @@ ArmKleidiAI::MlasDynamicQGemmBatch(
     MLAS_THREADPOOL* ThreadPool
 ) {
 
-    const size_t mr = UseSME2 ? kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa()
-                              : kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa();
-    const size_t kr = UseSME2 ? kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa()
-                              : kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa();
-    const size_t sr = UseSME2 ? kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa()
-                              : kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa();
+    const size_t mr = qgemm_gemm.get_mr();
+    const size_t kr = qgemm_gemm.get_kr();
+    const size_t sr = qgemm_gemm.get_sr();
 
-    size_t m_step = UseSME2 ? kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa()
-                            : kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa();
-    size_t n_step = UseSME2 ? kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa()
-                            : kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa();
+    size_t m_step = qgemm_gemm.get_m_step();
+    size_t n_step = qgemm_gemm.get_n_step();
 
     if (BatchSize == 0 || Shape.M == 0 || Shape.N == 0 || Shape.K == 0) {
         return;
@@ -216,17 +209,13 @@ ArmKleidiAI::MlasDynamicQGemmBatch(
         ptrdiff_t NIdx = (tid % (dim[1] * dim[2])) % dim[2];
 
         // Get rhs tile, B
-        const size_t rhs_packed_offset =
-            UseSME2 ? kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(NIdx * n_step, Shape.K)
-                    : kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa(NIdx * n_step, Shape.K);
+        const size_t rhs_packed_offset = qgemm_gemm.get_rhs_packed_offset(NIdx * n_step, Shape.K);
 
         const std::byte* B_base = reinterpret_cast<const std::byte*>(DataParams[BIdx].PackedB);
         auto BTile = reinterpret_cast<const void*>(B_base + rhs_packed_offset);
 
         // Get lhs tile, A
-        const size_t lhs_packed_offset =
-            UseSME2 ? kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(MIdx * m_step, Shape.K)
-                    : kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa(MIdx * m_step, Shape.K);
+        const size_t lhs_packed_offset = qgemm_gemm.get_lhs_packed_offset(MIdx * m_step, Shape.K);
 
         const std::byte* A_base = tls_lhs_base[BIdx]; // LhsPackedData + LhsPackedStride * BIdx; OR DataParams[batch_idx].Workspace;
         auto ATile = reinterpret_cast<const std::byte*>(A_base + lhs_packed_offset);
@@ -240,46 +229,12 @@ ArmKleidiAI::MlasDynamicQGemmBatch(
             NIdx * n_step * sizeof(float)
         );
 
-        if (UseSME2) {
-            kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(
+        qgemm_gemm.run_matmul(
             TileSizeM, TileSizeN, Shape.K, ATile, BTile,
             dst_tile,
             DataParams[BIdx].ldc * sizeof(float),
             sizeof(float),
             -std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
         );
-        }
-        else {
-#if defined(ENABLE_QMX_KERNELS)
-            if (ArmKleidiAI::vendor_name.compare("Qualcomm") == 0)
-            {
-                kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_qmx_mopa(
-                    TileSizeM, TileSizeN, Shape.K, ATile, BTile,
-                    dst_tile,
-                    DataParams[BIdx].ldc * sizeof(float),
-                    sizeof(float),
-                    -std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
-                );
-            }
-            else
-            {
-                kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa(
-                    TileSizeM, TileSizeN, Shape.K, ATile, BTile,
-                    dst_tile,
-                    DataParams[BIdx].ldc * sizeof(float),
-                    sizeof(float),
-                    -std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
-                );
-            }
-#else
-            kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa(
-                TileSizeM, TileSizeN, Shape.K, ATile, BTile,
-                dst_tile,
-                DataParams[BIdx].ldc * sizeof(float),
-                sizeof(float),
-                -std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
-            );
-#endif // ENABLE_QMX_KERNELS
-        }
     });
 }
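
The diff above replaces the per-call-site `UseSME2` ternaries and the vendor-specific `#if defined(ENABLE_QMX_KERNELS)` branching with a single file-scope table, `qgemm_gemm`, obtained once from `GetKleidiAIQGemmUKernel()`. The new `kai_ukernel_interface.h` is not included in this diff, so the following is only a minimal sketch of what that helper might look like, assuming `kai_matmul_clamp_f32_qai8dxp_qsi8cxp_ukernel` is KleidiAI's generic function-pointer table (its interface header path below is taken from the KleidiAI layout used elsewhere in this file) and that selection is between the sme2_mopa and sme_mopa variants named in the removed code; the `use_sme2` parameter and its default are hypothetical stand-ins for the real capability check.

// Sketch only: the actual kai_ukernel_interface.h is not part of this diff.
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp_qsi8cxp_interface.h"
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h"
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa.h"

// use_sme2 is a hypothetical selector; the real implementation presumably performs
// SME2 detection (and possibly the vendor check for the QMX variant) internally.
inline kai_matmul_clamp_f32_qai8dxp_qsi8cxp_ukernel GetKleidiAIQGemmUKernel(bool use_sme2 = true) {
    kai_matmul_clamp_f32_qai8dxp_qsi8cxp_ukernel uk{};
    if (use_sme2) {
        // SME2 variant: same function names the old code called directly.
        uk.get_m_step            = kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa;
        uk.get_n_step            = kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa;
        uk.get_mr                = kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa;
        uk.get_nr                = kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa;
        uk.get_kr                = kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa;
        uk.get_sr                = kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa;
        uk.get_lhs_packed_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa;
        uk.get_rhs_packed_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa;
        uk.run_matmul            = kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa;
    } else {
        // SME fallback variant.
        uk.get_m_step            = kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa;
        uk.get_n_step            = kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa;
        uk.get_mr                = kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa;
        uk.get_nr                = kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa;
        uk.get_kr                = kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa;
        uk.get_sr                = kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa;
        uk.get_lhs_packed_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa;
        uk.get_rhs_packed_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa;
        uk.run_matmul            = kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa;
    }
    return uk;
}

With the table populated once this way, MlasDynamicQGemmPackBSize, MlasDynamicQGemmPackB, and MlasDynamicQGemmBatch all read blocking parameters, packed offsets, and the matmul entry point through the same members (get_nr/get_kr/get_sr, get_mr, get_m_step/get_n_step, get_lhs_packed_offset/get_rhs_packed_offset, run_matmul), so variant selection happens in one place instead of at every call site.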