Skip to content

Commit 9edd107

Browse files
committed
updates for review comments
1 parent 3e08f37 commit 9edd107

File tree

3 files changed

+114
-10
lines changed

3 files changed

+114
-10
lines changed

ggml/src/ggml-cpu/CMakeLists.txt

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -111,14 +111,15 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
111111
# check_arm_feature(<tag> <code>)
#
# Probes one ARM architecture extension and records the outcome in the
# caller's ARM_MCPU_FLAG_FIX variable.
#
#   tag  - feature modifier name, appended to ${ARM_MCPU_FLAG} as "+<tag>"
#          (or "+no<tag>" in the fallback path).
#   code - C++ source handed to check_cxx_source_runs() while "+<tag>" is
#          enabled; presumably it exercises the feature - confirm at call sites.
#
# Result (written with PARENT_SCOPE):
#   * "+<tag>"   is appended when <code> builds and runs with the feature on.
#   * "+no<tag>" is appended when that fails but the compiler accepts the
#     "+no<tag>" modifier.
#   * nothing is appended when neither probe succeeds.
function(check_arm_feature tag code)
    # Preserve the caller-visible flags; they are restored before returning.
    set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})

    set(CMAKE_REQUIRED_FLAGS "${ARM_MCPU_FLAG}+${tag}")
    check_cxx_source_runs("${code}" GGML_MACHINE_SUPPORTS_${tag})

    if (GGML_MACHINE_SUPPORTS_${tag})
        set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+${tag}" PARENT_SCOPE)
    else()
        # Only verify that the compiler understands "+no<tag>". A trivial
        # program is used on purpose: compiling <code> here would require the
        # very feature that "+no<tag>" switches off, so the probe could not
        # succeed and "+no<tag>" would never be emitted.
        set(CMAKE_REQUIRED_FLAGS "${ARM_MCPU_FLAG}+no${tag}")
        check_cxx_source_compiles("int main() { return 0; }" GGML_MACHINE_SUPPORTS_no${tag})
        if (GGML_MACHINE_SUPPORTS_no${tag})
            set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+no${tag}" PARENT_SCOPE)
        endif()
    endif()

    set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
endfunction()
@@ -370,9 +371,13 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
370371
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/
371372
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/)
372373

373-
string(FIND "${ARCH_FLAGS}" "+dotprod" DOTPROD_ENABLED)
374-
string(FIND "${ARCH_FLAGS}" "+i8mm" I8MM_ENABLED)
375-
string(FIND "${ARCH_FLAGS}" "+sme" SME_ENABLED)
374+
set(ARCH_FLAGS_TEMP "${ARCH_FLAGS}")
375+
if (NOT ARCH_FLAGS_TEMP)
376+
string(REGEX MATCH "-march=[^ ]+" ARCH_FLAGS_TEMP "${CMAKE_C_FLAGS}")
377+
endif()
378+
string(FIND "${ARCH_FLAGS_TEMP}" "+dotprod" DOTPROD_ENABLED)
379+
string(FIND "${ARCH_FLAGS_TEMP}" "+i8mm" I8MM_ENABLED)
380+
string(FIND "${ARCH_FLAGS_TEMP}" "+sme" SME_ENABLED)
376381

377382
set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS})
378383

ggml/src/ggml-cpu/ggml-kleidiai/ggml-kleidiai.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "ggml-cpu.h"
2424
#include "ggml-impl.h"
2525
#include "ggml-backend-impl.h"
26+
#include "ggml-threading.h"
2627

2728
#include "kleidiai_kernels.h"
2829

@@ -35,6 +36,8 @@ struct ggml_kleidiai_context {
3536
} static ctx = { NULL };
3637

3738
static void init_kleidiai_context(void) {
39+
40+
ggml_critical_section_start();
3841
static bool initialized = false;
3942

4043
if (!initialized) {
@@ -55,6 +58,12 @@ static void init_kleidiai_context(void) {
5558
}
5659
ctx.kernels = ggml_kleidiai_select_kernels(features);
5760
}
61+
ggml_critical_section_end();
62+
}
63+
64+
static inline int ggml_ne(const ggml_tensor * tensor, int dim) {
65+
GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
66+
return tensor->ne[dim];
5867
}
5968

6069
namespace ggml::cpu::kleidiai {
@@ -237,7 +246,8 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
237246
if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
238247
return false;
239248
}
240-
if (op->src[1]->type == GGML_TYPE_F32) {
249+
if (op->src[1]->type == GGML_TYPE_F32 &&
250+
ggml_ne(op->src[1], 2) == 1 && ggml_ne(op->src[1], 3) == 1) {
241251
return true;
242252
}
243253
}

ggml/src/ggml-cpu/ggml-kleidiai/kleidiai_kernels.cpp

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
6363
/* .required_cpu = */ CPU_FEATURE_SME,
6464
},
6565
#endif
66+
#if defined(__APPLE__)
6667
#if defined(__ARM_FEATURE_DOTPROD)
6768
{
6869
/* DOTPROD GEMM */
@@ -149,6 +150,94 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
149150
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
150151
},
151152
#endif
153+
#else
154+
#if defined(__ARM_FEATURE_MATMUL_INT8)
155+
{
156+
/* i8mm GEMM */
157+
/* .kern_info = */ {
158+
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
159+
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
160+
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
161+
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
162+
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
163+
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
164+
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
165+
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
166+
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
167+
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
168+
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
169+
},
170+
/* i8mm GEMV */
171+
/* .kern_info = */ {
172+
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
173+
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
174+
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
175+
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
176+
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
177+
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
178+
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
179+
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
180+
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
181+
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
182+
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
183+
},
184+
/* .lhs_info = */ {
185+
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
186+
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
187+
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
188+
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
189+
},
190+
/* .rhs_info = */ {
191+
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
192+
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
193+
},
194+
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
195+
},
196+
#endif
197+
#if defined(__ARM_FEATURE_DOTPROD)
198+
{
199+
/* DOTPROD GEMM */
200+
/* .kern_info = */ {
201+
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
202+
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
203+
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
204+
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
205+
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
206+
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
207+
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
208+
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
209+
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
210+
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
211+
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
212+
},
213+
/* DOTPROD GEMV */
214+
/* .kern_info = */ {
215+
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
216+
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
217+
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
218+
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
219+
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
220+
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
221+
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
222+
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
223+
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
224+
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
225+
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
226+
},
227+
/* .lhs_info = */ {
228+
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
229+
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
230+
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
231+
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
232+
},
233+
/* .rhs_info = */ {
234+
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
235+
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
236+
},
237+
/* .required_cpu = */ CPU_FEATURE_DOTPROD,
238+
},
239+
#endif
240+
#endif
152241
};
153242

154243
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature features) {

0 commit comments

Comments
 (0)