88
99#include " kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h"
1010#include " kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.h"
11- #include " kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h"
1211
1312#include " mlas.h"
1413
@@ -24,7 +23,7 @@ struct KaiTlsBuffersSbgemm {
2423};
2524static thread_local KaiTlsBuffersSbgemm g_kai_tls_sbgemm;
2625
27- kai_matmul_clamp_f32_bf16p_bf16p_ukernel sbgemm_gemm = GetKleidiAISBGemmUKernel();
26+ KaiBF16SBgemmKernel sbgemm_gemm = GetKleidiAISBGemmUKernel();
2827
2928/* ++
3029Routine Description:
@@ -192,9 +191,9 @@ Return Value:
192191 return false ;
193192 }
194193
195- const size_t nr = sbgemm_gemm.get_nr ();
196- const size_t kr = sbgemm_gemm.get_kr ();
197- const size_t sr = sbgemm_gemm.get_sr ();
194+ const size_t nr = sbgemm_gemm.ukernel . get_nr ();
195+ const size_t kr = sbgemm_gemm.ukernel . get_kr ();
196+ const size_t sr = sbgemm_gemm.ukernel . get_sr ();
198197
199198 // Ensure size and zero the used span.
200199 g_kai_tls_sbgemm.bias_zero .resize (N, 0 .0f );
@@ -257,17 +256,17 @@ Return Value:
257256 return true ;
258257 }
259258
260- size_t m_step = sbgemm_gemm.get_m_step ();
261- size_t n_step = sbgemm_gemm.get_n_step ();
259+ size_t m_step = sbgemm_gemm.ukernel . get_m_step ();
260+ size_t n_step = sbgemm_gemm.ukernel . get_n_step ();
262261
263262 if ((M < m_step || N < n_step) && !Data->BIsPacked ) {
264263 // Fallback
265264 return false ;
266265 }
267266
268- const size_t mr = sbgemm_gemm.get_mr ();
269- const size_t kr = sbgemm_gemm.get_kr ();
270- const size_t sr = sbgemm_gemm.get_sr ();
267+ const size_t mr = sbgemm_gemm.ukernel . get_mr ();
268+ const size_t kr = sbgemm_gemm.ukernel . get_kr ();
269+ const size_t sr = sbgemm_gemm.ukernel . get_sr ();
271270
272271 size_t LhsPackedStride = 0 ;
273272 std::byte* LhsPackedData = nullptr ;
@@ -362,15 +361,15 @@ Return Value:
362361 ptrdiff_t NIdx = (tid % (dim[1 ] * dim[2 ])) % dim[2 ];
363362
364363 // Get rhs tile, B
365- const size_t rhs_packed_offset = sbgemm_gemm.get_rhs_packed_offset (NIdx * n_step, K);
364+ const size_t rhs_packed_offset = sbgemm_gemm.ukernel . get_rhs_packed_offset (NIdx * n_step, K);
366365
367366 const std::byte* B_base = Data[0 ].BIsPacked
368367 ? reinterpret_cast <const std::byte*>(Data[BIdx].B )
369368 : (RhsPackedData + RhsPackedStride * BIdx);
370369 auto BTile = reinterpret_cast <const void *>(B_base + rhs_packed_offset);
371370
372371 // Get lhs tile, A
373- const size_t lhs_packed_offset = sbgemm_gemm.get_lhs_packed_offset (MIdx * m_step, K);
372+ const size_t lhs_packed_offset = sbgemm_gemm.ukernel . get_lhs_packed_offset (MIdx * m_step, K);
374373
375374 const std::byte* A_base = LhsPackedData + LhsPackedStride * BIdx;
376375 auto ATile = reinterpret_cast <const float *>(A_base + lhs_packed_offset);
@@ -393,7 +392,9 @@ Return Value:
393392 float * temp_tile = g_kai_tls_sbgemm.output_tile .data ();
394393 std::fill_n (temp_tile, tile_elems, 0 .0f );
395394
396- sbgemm_gemm.run_matmul (
395+ KLEIDIAI_KERNEL_LOG (sbgemm_gemm.name
396+ << " M=" << TileSizeM << " N=" << TileSizeN << " K=" << K);
397+ sbgemm_gemm.ukernel .run_matmul (
397398 TileSizeM,
398399 TileSizeN,
399400 K,
0 commit comments