Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ggml/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
cmake_minimum_required(VERSION 3.14) # for add_link_options and implicit target directories.
project("ggml" C CXX)
project("ggml" C CXX ASM)
include(CheckIncludeFileCXX)

set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
Expand Down
8 changes: 5 additions & 3 deletions ggml/src/ggml-cpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -497,9 +497,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)

# Fetch KleidiAI sources:
include(FetchContent)
set(KLEIDIAI_COMMIT_TAG "v1.11.0")
set(KLEIDIAI_COMMIT_TAG "v1.13.0")
set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
set(KLEIDIAI_ARCHIVE_MD5 "3fe9e5ab964c375c53839296eb71eaa2")
set(KLEIDIAI_ARCHIVE_MD5 "d82a8de939d9814621a5ba23907bdac1")

if (POLICY CMP0135)
cmake_policy(SET CMP0135 NEW)
Expand Down Expand Up @@ -555,6 +555,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)

list(APPEND GGML_KLEIDIAI_SOURCES
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c)
Expand All @@ -576,7 +577,8 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c)
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c
${KLEIDIAI_SRC}/kai/kai_common_sme_asm.S)
set(PRIVATE_ARCH_FLAGS "-fno-tree-vectorize;${PRIVATE_ARCH_FLAGS}+sve+sve2")
endif()

Expand Down
49 changes: 43 additions & 6 deletions ggml/src/ggml-cpu/kleidiai/kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "kai_lhs_pack_bf16p2vlx2_f32_sme.h"
#include "kai_lhs_quant_pack_qsi8d32p_f32.h"
#include "kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h"
#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"

#include "kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h"
Expand Down Expand Up @@ -127,6 +128,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
},
/* .gemm_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon,
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
},
/* SME GEMV */
/* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
Expand All @@ -141,7 +148,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
},
/* .lhs_info = */ {
/* .gemv_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon,
Expand Down Expand Up @@ -173,6 +180,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
},
/* .gemm_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme,
/* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme,
},
/* SME GEMV */
/* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
Expand All @@ -187,7 +200,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
},
/* .lhs_info = */ {
/* .gemv_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme,
Expand Down Expand Up @@ -222,6 +235,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
},
/* .gemm_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
},
/* DOTPROD GEMV */
/* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
Expand All @@ -236,7 +255,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
},
/* .lhs_info = */ {
/* .gemv_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
Expand Down Expand Up @@ -270,6 +289,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
},
/* .gemm_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
},
/* i8mm GEMV */
/* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
Expand All @@ -284,7 +309,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
},
/* .lhs_info = */ {
/* .gemv_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
Expand Down Expand Up @@ -319,6 +344,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
},
/* .gemm_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
},
/* i8mm GEMV */
/* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
Expand All @@ -333,7 +364,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
},
/* .lhs_info = */ {
/* .gemv_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
Expand Down Expand Up @@ -367,6 +398,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
},
/* .gemm_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
},
/* DOTPROD GEMV */
/* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
Expand All @@ -381,7 +418,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
},
/* .lhs_info = */ {
/* .gemv_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
Expand Down
5 changes: 4 additions & 1 deletion ggml/src/ggml-cpu/kleidiai/kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,11 @@ struct rhs_packing_info {

struct ggml_kleidiai_kernels {
kernel_info gemm;
lhs_packing_info gemm_lhs_info;

kernel_info gemv;
lhs_packing_info lhs_info;
lhs_packing_info gemv_lhs_info;

rhs_packing_info rhs_info;

cpu_feature required_cpu;
Expand Down
23 changes: 14 additions & 9 deletions ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
}
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
GGML_ASSERT(kernels);
kernel_info * kernel = op->src[1]->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
bool is_gemv = op->src[1]->ne[1] == 1;
kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;

size_t k = op->src[0]->ne[0];
size_t n = op->src[0]->ne[1];
Expand All @@ -134,9 +136,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
size_t sr = kernel->get_sr();

if (kernels->rhs_type == GGML_TYPE_Q4_0) {
size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, QK4_0, mr, kr, sr);
size = variant_call<size_t>(lhs_info->packed_size, m, k, QK4_0, mr, kr, sr);
} else if (kernels->rhs_type == GGML_TYPE_F16) {
size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, mr, kr, sr) +
size = variant_call<size_t>(lhs_info->packed_size, m, k, mr, kr, sr) +
variant_call<size_t>(kernels->rhs_info.packed_size, n, k) +
k * n * sizeof(float) + n * sizeof(float);
} else {
Expand Down Expand Up @@ -173,7 +175,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
GGML_ASSERT(kernels);

kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
bool is_gemv = src1->ne[1] == 1;
kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
GGML_ASSERT(kernel);

const int nth = params->nth;
Expand All @@ -198,7 +202,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
const int64_t kr = static_cast<int64_t>(kernel->get_kr());
const int64_t sr = static_cast<int64_t>(kernel->get_sr());

const size_t lhs_packed_size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, mr, kr, sr);
const size_t lhs_packed_size = variant_call<size_t>(lhs_info->packed_size, m, k, mr, kr, sr);
const size_t rhs_packed_size = variant_call<size_t>(kernels->rhs_info.packed_size, n, k);
const size_t kxn_size = k * n * sizeof(float);
const size_t bias_size = n * sizeof(float);
Expand Down Expand Up @@ -229,12 +233,12 @@ class tensor_traits : public ggml::cpu::tensor_traits {
const int64_t num_m_per_thread = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;

const size_t lhs_offset = variant_call<size_t>(kernels->gemm.get_lhs_offset, m_start, lhs_stride);
const size_t lhs_packed_offset = variant_call<size_t>(kernels->lhs_info.get_packed_offset, m_start, k, mr, kr, sr);
const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, m_start, k, mr, kr, sr);

const void * src_ptr = static_cast<const uint8_t *>(lhs_batch) + lhs_offset;
void * dst_ptr = static_cast<uint8_t *>(lhs_packed) + lhs_packed_offset;

variant_call<void>(kernels->lhs_info.pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);
variant_call<void>(lhs_info->pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);
}
}

Expand Down Expand Up @@ -306,8 +310,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
GGML_ASSERT(kernels);

kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
lhs_packing_info * lhs_info = &kernels->lhs_info;
bool is_gemv = src1->ne[1] == 1;
kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;

GGML_ASSERT(kernel);

Expand Down
Loading