diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt
index 369905750754f..70cec1814c184 100644
--- a/ggml/src/ggml-cpu/CMakeLists.txt
+++ b/ggml/src/ggml-cpu/CMakeLists.txt
@@ -504,9 +504,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
 
         # Fetch KleidiAI sources:
         include(FetchContent)
-        set(KLEIDIAI_COMMIT_TAG "v1.13.0")
+        set(KLEIDIAI_COMMIT_TAG "v1.14.0")
         set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
-        set(KLEIDIAI_ARCHIVE_MD5  "d82a8de939d9814621a5ba23907bdac1")
+        set(KLEIDIAI_ARCHIVE_MD5  "45e110675d93f99f82c23a1afcca76bc")
 
         if (POLICY CMP0135)
             cmake_policy(SET CMP0135 NEW)
@@ -583,6 +583,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa_asm.S
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c
                 ${KLEIDIAI_SRC}/kai/kai_common_sme_asm.S)
diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
index 8694ee15d3fe0..44691e5dfdf6a 100644
--- a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
+++ b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
@@ -87,15 +87,38 @@ static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
     return tensor->ne[dim];
 }
 
+template <typename Variant, typename Ret, typename... Args, std::size_t... Is>
+constexpr bool variant_any_invocable_impl(std::index_sequence<Is...>) {
+    using V = std::remove_reference_t<Variant>;
+    return (std::is_invocable_r_v<
+                Ret,
+                std::variant_alternative_t<Is, V>,
+                Args...> || ...);
+}
+
+template <typename Variant, typename Ret, typename... Args>
+constexpr bool variant_any_invocable_v =
+    variant_any_invocable_impl<Variant, Ret, Args...>(
+        std::make_index_sequence<
+            std::variant_size_v<std::remove_reference_t<Variant>>>{});
+
 template<typename Ret, typename Variant, typename... Args>
-static Ret variant_call(const Variant & var, Args&&... args) {
-    return std::visit([&](auto&& func) -> Ret {
-        if constexpr (std::is_invocable_r_v<Ret, decltype(func), Args...>) {
-            return func(std::forward<Args>(args)...);
-        } else {
-            throw std::runtime_error("Invalid function type in variant_call");
-        }
-    }, var);
+static inline Ret variant_call(Variant && var, Args&&... args) {
+    static_assert(variant_any_invocable_v<Variant, Ret, Args...>,
+        "No alternative in Variant is invocable with the provided arguments and return type.");
+
+    return std::visit(
+        [&](auto && f) -> Ret {
+            using F = std::decay_t<decltype(f)>;
+            if constexpr (std::is_invocable_r_v<Ret, F, Args...>) {
+                return std::invoke(std::forward<decltype(f)>(f), std::forward<Args>(args)...);
+            } else {
+                GGML_ABORT("Invalid function type in variant_call");
+                GGML_UNREACHABLE();
+            }
+        },
+        std::forward<Variant>(var)
+    );
 }
 
 namespace ggml::cpu::kleidiai {
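For reference, the dispatch pattern used by the new variant_call above can be exercised in isolation: a std::variant of function-pointer signatures is visited at runtime, and a fold over the variant's alternatives rejects, at compile time, any call that no alternative could satisfy. The sketch below is illustrative only; all names and signatures in it are invented and are not taken from kleidiai.cpp:

    #include <cstddef>
    #include <functional>
    #include <stdexcept>
    #include <type_traits>
    #include <utility>
    #include <variant>

    // Hypothetical variant of packing-size callbacks with different arities.
    using packed_size_fn = std::variant<size_t (*)(size_t, size_t),
                                        size_t (*)(size_t, size_t, size_t)>;

    template <typename Variant, typename Ret, typename... Args, std::size_t... Is>
    constexpr bool any_invocable_impl(std::index_sequence<Is...>) {
        using V = std::remove_reference_t<Variant>;
        // True if at least one variant alternative is callable as Ret(Args...).
        return (std::is_invocable_r_v<Ret, std::variant_alternative_t<Is, V>, Args...> || ...);
    }

    template <typename Variant, typename Ret, typename... Args>
    constexpr bool any_invocable_v =
        any_invocable_impl<Variant, Ret, Args...>(
            std::make_index_sequence<std::variant_size_v<std::remove_reference_t<Variant>>>{});

    template <typename Ret, typename Variant, typename... Args>
    static Ret call(Variant && var, Args &&... args) {
        // Reject impossible calls at compile time, dispatch the rest via std::visit.
        static_assert(any_invocable_v<Variant, Ret, Args...>, "no matching alternative");
        return std::visit([&](auto && f) -> Ret {
            if constexpr (std::is_invocable_r_v<Ret, std::decay_t<decltype(f)>, Args...>) {
                return std::invoke(std::forward<decltype(f)>(f), std::forward<Args>(args)...);
            } else {
                // Not reached for the alternative actually stored, but needed so
                // every lambda instantiation has a well-formed body returning Ret.
                throw std::runtime_error("stored alternative does not match this call");
            }
        }, std::forward<Variant>(var));
    }

    static size_t packed_size_nk(size_t n, size_t k) { return n * k; }

    int main() {
        packed_size_fn fn = &packed_size_nk;
        const size_t s = call<size_t>(fn, size_t(4), size_t(8)); // 32
        return s == 32 ? 0 : 1;
    }

A call with an argument list that no alternative accepts fails at the static_assert instead of surfacing as a runtime error inside std::visit, which is the behavioral difference from the removed implementation.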
@@ -138,7 +161,10 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         if (kernels->rhs_type == GGML_TYPE_Q4_0) {
             size = variant_call<size_t>(lhs_info->packed_size, m, k, QK4_0, mr, kr, sr);
         } else if (kernels->rhs_type == GGML_TYPE_F16) {
-            size = variant_call<size_t>(lhs_info->packed_size, m, k, mr, kr, sr) +
+            const int64_t lhs_batch_size0 = op->src[1]->ne[2];
+            const int64_t rhs_batch_size0 = op->src[0]->ne[2];
+            const int64_t r = lhs_batch_size0 / rhs_batch_size0;
+            size = variant_call<size_t>(lhs_info->packed_size, m * r, k, mr, kr, sr) +
                    variant_call<size_t>(kernels->rhs_info.packed_size, n, k) +
                    k * n * sizeof(float) + n * sizeof(float);
         } else {
@@ -148,7 +174,6 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         return true;
     }
 
-
     bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {
         if (dst->op == GGML_OP_MUL_MAT) {
             if (dst->src[0]->type == GGML_TYPE_Q4_0) {
@@ -165,8 +190,6 @@ class tensor_traits : public ggml::cpu::tensor_traits {
     }
 
     bool compute_forward_fp16(ggml_compute_params * params, struct ggml_tensor * dst) {
-        static std::atomic_flag first_to_arrive = ATOMIC_FLAG_INIT;
-
         const ggml_tensor * src0 = dst->src[0];
         const ggml_tensor * src1 = dst->src[1];
 
@@ -175,7 +198,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
         GGML_ASSERT(kernels);
 
-        bool is_gemv = src1->ne[1] == 1;
+        const bool is_gemv = src1->ne[1] == 1;
         kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
         lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
         GGML_ASSERT(kernel);
@@ -185,27 +208,30 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         const int64_t lhs_batch_size0 = ne12;
         const int64_t rhs_batch_size0 = ne02;
-        const int64_t batch_size = rhs_batch_size0;
+        const int64_t batch_size = lhs_batch_size0;
+        GGML_ASSERT(rhs_batch_size0 > 0);
+        GGML_ASSERT(lhs_batch_size0 % rhs_batch_size0 == 0);
         const int64_t r = lhs_batch_size0 / rhs_batch_size0;
 
-        const int64_t m = ne11 * r;
-        const int64_t n = ne01;
-        const int64_t k = ne00;
+        const int64_t m_group = ne11;
+        const int64_t m = m_group;
+        const int64_t n = ne01;
+        const int64_t k = ne00;
 
         const size_t lhs_stride = src1->nb[1];
         const size_t rhs_stride = src0->nb[1];
         const size_t dst_stride = dst->nb[1];
 
-        const int64_t mr = static_cast<int64_t>(kernel->get_mr());
-        const int64_t nr = static_cast<int64_t>(kernel->get_nr());
-        const int64_t kr = static_cast<int64_t>(kernel->get_kr());
-        const int64_t sr = static_cast<int64_t>(kernel->get_sr());
+        const int64_t mr = (int64_t) kernel->get_mr();
+        const int64_t nr = (int64_t) kernel->get_nr();
+        const int64_t kr = (int64_t) kernel->get_kr();
+        const int64_t sr = (int64_t) kernel->get_sr();
 
-        const size_t lhs_packed_size = variant_call<size_t>(lhs_info->packed_size, m, k, mr, kr, sr);
-        const size_t rhs_packed_size = variant_call<size_t>(kernels->rhs_info.packed_size, n, k);
-        const size_t kxn_size        = k * n * sizeof(float);
-        const size_t bias_size       = n * sizeof(float);
+        const size_t lhs_packed_size = variant_call<size_t>(lhs_info->packed_size, (size_t)m, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
+        const size_t rhs_packed_size = variant_call<size_t>(kernels->rhs_info.packed_size, (size_t)n, (size_t)k);
+        const size_t kxn_size        = (size_t)k * (size_t)n * sizeof(float);
+        const size_t bias_size       = (size_t)n * sizeof(float);
 
         const size_t wsize_required = lhs_packed_size + rhs_packed_size + kxn_size + bias_size;
         GGML_ASSERT(wsize_required <= params->wsize);
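The hunk above switches the FP16 path to iterate over the LHS batch count (ne12) and to broadcast each RHS batch over r = ne12 / ne02 consecutive output batches; the next hunk applies that mapping as rhs_batch_idx = batch_idx / r. A small standalone sketch of the index mapping, with assumed batch counts (the numbers are illustrative, not taken from the patch):

    #include <cassert>
    #include <cstdint>
    #include <cstdio>

    int main() {
        // Assumed shapes: 8 LHS/dst batches (ne12) sharing 2 RHS batches (ne02).
        const int64_t lhs_batch_size0 = 8;
        const int64_t rhs_batch_size0 = 2;
        assert(rhs_batch_size0 > 0);
        assert(lhs_batch_size0 % rhs_batch_size0 == 0); // same precondition the patch asserts
        const int64_t r = lhs_batch_size0 / rhs_batch_size0;

        for (int64_t batch_idx = 0; batch_idx < lhs_batch_size0; ++batch_idx) {
            // In the patch, base pointers are then formed as
            //   src0->data + rhs_batch_idx * src0->nb[2]  and  dst->data + batch_idx * dst->nb[2].
            const int64_t rhs_batch_idx = batch_idx / r; // 0,0,0,0,1,1,1,1
            std::printf("dst batch %lld <- rhs batch %lld\n",
                        (long long) batch_idx, (long long) rhs_batch_idx);
        }
        return 0;
    }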
@@ -216,82 +242,102 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         uint8_t * bias       = rhs_kxn + kxn_size;
 
         for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
-            const uint8_t * lhs_batch = static_cast<const uint8_t *>(src1->data) + batch_idx * m * lhs_stride;
-            const uint8_t * rhs_batch = static_cast<const uint8_t *>(src0->data) + batch_idx * n * rhs_stride;
-            uint8_t * dst_batch       = static_cast<uint8_t *>(dst->data) + batch_idx * m * dst_stride;
+            const int64_t rhs_batch_idx = batch_idx / r;
+            const uint8_t * rhs_batch_base = static_cast<const uint8_t *>(src0->data) + rhs_batch_idx * src0->nb[2];
+            uint8_t * dst_batch_base       = static_cast<uint8_t *>(dst->data) + batch_idx * dst->nb[2];
 
-            // LHS packing
+            // LHS packing (threaded over m, honoring mr alignment and KV groups)
             {
                 const int64_t m_roundup_mr = kai_roundup(m, mr);
                 const int64_t num_threads  = KAI_MIN(m_roundup_mr / mr, nth);
 
                 if (ith < num_threads) {
-                    const int64_t num_m_per_thread0   = round_down(m_roundup_mr / num_threads, mr);
+                    const int64_t num_m_per_thread0   = round_down((size_t)(m_roundup_mr / num_threads), (size_t)mr);
                     const int64_t num_m_per_threadN_1 = m - (num_threads - 1) * num_m_per_thread0;
 
-                    const int64_t m_start = ith * num_m_per_thread0;
-                    const int64_t num_m_per_thread = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
+                    const int64_t m_start = ith * num_m_per_thread0;
+                    const int64_t m_count = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
+
+                    // Base packed offset (aligned) and per-row stride in bytes
+                    const size_t base_packed_off = variant_call<size_t>(
+                        lhs_info->get_packed_offset, (size_t)m_start, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
+                    const size_t next_block_off = variant_call<size_t>(
+                        lhs_info->get_packed_offset, (size_t)(m_start + mr), (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
+                    const size_t row_stride_bytes = (next_block_off - base_packed_off) / (size_t)mr;
+
+                    int64_t remaining = m_count;
+                    int64_t cur       = m_start;
+
+                    while (remaining > 0) {
+                        const int64_t row_in_group = cur;
+                        const int64_t avail        = m_group - row_in_group;
+                        const int64_t take         = std::min(avail, remaining);
 
-                    const size_t lhs_offset        = variant_call<size_t>(kernels->gemm.get_lhs_offset, m_start, lhs_stride);
-                    const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, m_start, k, mr, kr, sr);
+                        const uint8_t * lhs_batch_base = static_cast<const uint8_t *>(src1->data) + batch_idx * src1->nb[2];
+                        const void * src_ptr = lhs_batch_base + (size_t)row_in_group * lhs_stride;
+                        const size_t dst_off = base_packed_off + (size_t)(cur - m_start) * row_stride_bytes;
+                        void * dst_ptr       = lhs_packed + dst_off;
 
-                    const void * src_ptr = static_cast<const uint8_t *>(lhs_batch) + lhs_offset;
-                    void * dst_ptr       = static_cast<uint8_t *>(lhs_packed) + lhs_packed_offset;
+                        variant_call<void>(lhs_info->pack_func,
+                            (size_t)take, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr,
+                            /*m_idx_start*/ 0, src_ptr, lhs_stride, dst_ptr);
 
-                    variant_call<void>(lhs_info->pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);
+                        cur       += take;
+                        remaining -= take;
+                    }
                 }
             }
 
-            // RHS packing
-            if (first_to_arrive.test_and_set(std::memory_order_acquire) == false) {
-                // First thread to reach this point handles RHS packing
-                memset(bias, 0, n * sizeof(float));
-                transpose_f32kxn_f16nxk(n, k, reinterpret_cast<float *>(rhs_kxn),
-                                        reinterpret_cast<const uint16_t *>(rhs_batch), rhs_stride);
-
-                variant_call<void>(kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, n * sizeof(float),
-                                   rhs_kxn, bias, nullptr, rhs_packed, 0, nullptr);
+            // RHS packing (single thread), then synchronize
+            if (ith == 0) {
+                memset(bias, 0, (size_t)n * sizeof(float));
+                transpose_f32kxn_f16nxk((size_t)n, (size_t)k,
+                                        reinterpret_cast<float *>(rhs_kxn),
+                                        reinterpret_cast<const uint16_t *>(rhs_batch_base),
+                                        rhs_stride);
+
+                variant_call<void>(kernels->rhs_info.pack_func,
+                    /*num_groups*/ 1, (size_t)n, (size_t)k, (size_t)nr, (size_t)kr, (size_t)sr,
+                    /*rhs_stride (bytes)*/ (size_t)(n * sizeof(float)),
+                    rhs_kxn, bias, nullptr, rhs_packed, /*extra_bytes*/ 0, /*params*/ nullptr);
             }
 
             ggml_barrier(params->threadpool);
 
-            first_to_arrive.clear(std::memory_order_release);
-
-            // Perform the matmul
+            // Matmul (threaded over n)
             {
-                const int64_t m_to_process = m;
-                const int64_t m_start = 0;
-
-                const int64_t n_step = static_cast<int64_t>(kernel->get_n_step());
-                int64_t num_threads = KAI_MIN(n / n_step, nth);
-                if (num_threads <= 0) {
-                    num_threads = 1;
+                const int64_t n_step = (int64_t) kernel->get_n_step();
+                int64_t num_threads_n = KAI_MIN(n / n_step, nth);
+                if (num_threads_n <= 0) {
+                    num_threads_n = 1;
                 }
 
-                if (ith < num_threads) {
-                    const int64_t num_n_per_thread0   = round_down(n / num_threads, n_step);
-                    const int64_t num_n_per_threadN_1 = n - (num_threads - 1) * num_n_per_thread0;
+                if (ith < num_threads_n) {
+                    const int64_t num_n_per_thread0   = round_down((size_t)(n / num_threads_n), (size_t)n_step);
+                    const int64_t num_n_per_threadN_1 = n - (num_threads_n - 1) * num_n_per_thread0;
 
                     const int64_t n_start      = ith * num_n_per_thread0;
-                    const int64_t n_to_process = (ith == num_threads - 1) ? num_n_per_threadN_1 : num_n_per_thread0;
+                    const int64_t n_to_process = (ith == num_threads_n - 1) ? num_n_per_threadN_1 : num_n_per_thread0;
 
-                    const size_t lhs_packed_offset = variant_call<size_t>(kernel->get_lhs_offset, m_start, k);
-                    const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, n_start, k);
-                    const size_t dst_offset        = kernel->get_dst_offset(m_start, n_start, dst_stride);
+                    // LHS packed base at row 0 (consistent with packing above)
+                    const size_t lhs_packed_offset0 = variant_call<size_t>(
+                        lhs_info->get_packed_offset, (size_t)0, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
+                    const size_t rhs_packed_offset  = variant_call<size_t>(kernel->get_rhs_packed_offset, (size_t)n_start, (size_t)k);
+                    const size_t dst_offset         = kernel->get_dst_offset((size_t)0, (size_t)n_start, dst_stride);
 
-                    const void * lhs_ptr = lhs_packed + lhs_packed_offset;
+                    const void * lhs_ptr = lhs_packed + lhs_packed_offset0;
                     const void * rhs_ptr = rhs_packed + rhs_packed_offset;
-                    float * dst_ptr      = reinterpret_cast<float *>(dst_batch + dst_offset);
+                    float * dst_ptr      = reinterpret_cast<float *>(dst_batch_base + dst_offset);
 
-                    variant_call<void>(kernel->run_kernel, m_to_process, n_to_process, k, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
+                    variant_call<void>(kernel->run_kernel,
+                        (size_t)m, (size_t)n_to_process, (size_t)k,
+                        lhs_ptr, rhs_ptr,
+                        dst_ptr, dst_stride, sizeof(float),
+                        -FLT_MAX, FLT_MAX);
                 }
             }
 
             if (batch_idx != batch_size - 1) {
-                // This barrier is necessary when the batch size is larger than 1. While processing a batch,
-                // the work data buffer (params->wdata) is used as temporary storage which means that only
-                // a single batch can be processed at any given time. No barrier is needed for the last
-                // batch since GGML inserts a barrier between the execution of every operator.
                 ggml_barrier(params->threadpool);
             }
         }
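The matmul block above splits the n dimension into n_step-aligned chunks across up to nth threads, with the last thread taking whatever remains. A standalone sketch of that partitioning under assumed sizes (round_down here is a local stand-in for the helper used in kleidiai.cpp; the numbers are illustrative, not from the patch):

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    // Local stand-in: largest multiple of step that is <= x.
    static int64_t round_down(int64_t x, int64_t step) { return x - (x % step); }

    int main() {
        const int64_t n = 520, n_step = 64, nth = 4; // assumed sizes
        const int64_t num_threads_n = std::max<int64_t>(1, std::min<int64_t>(n / n_step, nth));

        const int64_t per_thread0 = round_down(n / num_threads_n, n_step); // n_step-aligned chunk
        for (int64_t ith = 0; ith < num_threads_n; ++ith) {
            const int64_t n_start = ith * per_thread0;
            const int64_t n_count = (ith == num_threads_n - 1)
                                        ? n - (num_threads_n - 1) * per_thread0 // last thread takes the tail
                                        : per_thread0;
            std::printf("thread %lld: n_start=%lld n_to_process=%lld\n",
                        (long long) ith, (long long) n_start, (long long) n_count);
        }
        return 0;
    }

With these assumed sizes, threads 0..2 each process 128 columns starting at n_step-aligned offsets and the last thread processes the remaining 136, which mirrors how n_start and n_to_process are derived in the hunk.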