Skip to content

Commit 6a201f1

Browse files
Fixing logging + build issue
Signed-off-by: Patryk Kaiser <patryk.kaiser@arm.com>
1 parent 4bd9b5e commit 6a201f1

File tree

1 file changed

+14
-13
lines changed

1 file changed

+14
-13
lines changed

onnxruntime/core/mlas/lib/kleidiai/sbgemm_kleidiai.cpp

Lines changed: 14 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,6 @@
88

99
#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h"
1010
#include "kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.h"
11-
#include "kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h"
1211

1312
#include "mlas.h"
1413

@@ -24,7 +23,7 @@ struct KaiTlsBuffersSbgemm {
2423
};
2524
static thread_local KaiTlsBuffersSbgemm g_kai_tls_sbgemm;
2625

27-
kai_matmul_clamp_f32_bf16p_bf16p_ukernel sbgemm_gemm = GetKleidiAISBGemmUKernel();
26+
KaiBF16SBgemmKernel sbgemm_gemm = GetKleidiAISBGemmUKernel();
2827

2928
/*++
3029
Routine Description:
@@ -192,9 +191,9 @@ Return Value:
192191
return false;
193192
}
194193

195-
const size_t nr = sbgemm_gemm.get_nr();
196-
const size_t kr = sbgemm_gemm.get_kr();
197-
const size_t sr = sbgemm_gemm.get_sr();
194+
const size_t nr = sbgemm_gemm.ukernel.get_nr();
195+
const size_t kr = sbgemm_gemm.ukernel.get_kr();
196+
const size_t sr = sbgemm_gemm.ukernel.get_sr();
198197

199198
// Ensure size and zero the used span.
200199
g_kai_tls_sbgemm.bias_zero.resize(N, 0.0f);
@@ -257,17 +256,17 @@ Return Value:
257256
return true;
258257
}
259258

260-
size_t m_step = sbgemm_gemm.get_m_step();
261-
size_t n_step = sbgemm_gemm.get_n_step();
259+
size_t m_step = sbgemm_gemm.ukernel.get_m_step();
260+
size_t n_step = sbgemm_gemm.ukernel.get_n_step();
262261

263262
if ((M < m_step || N < n_step) && !Data->BIsPacked) {
264263
// Fallback
265264
return false;
266265
}
267266

268-
const size_t mr = sbgemm_gemm.get_mr();
269-
const size_t kr = sbgemm_gemm.get_kr();
270-
const size_t sr = sbgemm_gemm.get_sr();
267+
const size_t mr = sbgemm_gemm.ukernel.get_mr();
268+
const size_t kr = sbgemm_gemm.ukernel.get_kr();
269+
const size_t sr = sbgemm_gemm.ukernel.get_sr();
271270

272271
size_t LhsPackedStride = 0;
273272
std::byte* LhsPackedData = nullptr;
@@ -362,15 +361,15 @@ Return Value:
362361
ptrdiff_t NIdx = (tid % (dim[1] * dim[2])) % dim[2];
363362

364363
// Get rhs tile, B
365-
const size_t rhs_packed_offset = sbgemm_gemm.get_rhs_packed_offset(NIdx * n_step, K);
364+
const size_t rhs_packed_offset = sbgemm_gemm.ukernel.get_rhs_packed_offset(NIdx * n_step, K);
366365

367366
const std::byte* B_base = Data[0].BIsPacked
368367
? reinterpret_cast<const std::byte*>(Data[BIdx].B)
369368
: (RhsPackedData + RhsPackedStride * BIdx);
370369
auto BTile = reinterpret_cast<const void*>(B_base + rhs_packed_offset);
371370

372371
// Get lhs tile, A
373-
const size_t lhs_packed_offset = sbgemm_gemm.get_lhs_packed_offset(MIdx * m_step, K);
372+
const size_t lhs_packed_offset = sbgemm_gemm.ukernel.get_lhs_packed_offset(MIdx * m_step, K);
374373

375374
const std::byte* A_base = LhsPackedData + LhsPackedStride * BIdx;
376375
auto ATile = reinterpret_cast<const float*>(A_base + lhs_packed_offset);
@@ -393,7 +392,9 @@ Return Value:
393392
float* temp_tile = g_kai_tls_sbgemm.output_tile.data();
394393
std::fill_n(temp_tile, tile_elems, 0.0f);
395394

396-
sbgemm_gemm.run_matmul(
395+
KLEIDIAI_KERNEL_LOG(sbgemm_gemm.name
396+
<< " M=" << TileSizeM << " N=" << TileSizeN << " K=" << K);
397+
sbgemm_gemm.ukernel.run_matmul(
397398
TileSizeM,
398399
TileSizeN,
399400
K,

0 commit comments

Comments
 (0)