Skip to content

Commit 159c47f

Browse files
committed
Merge commit '335eb04a91f481f37c0c9b302ee31b449b04c3e9' into concedo_experimental
# Conflicts: # .github/workflows/build.yml # CONTRIBUTING.md # Makefile # docs/build.md # examples/llama.swiftui/llama.swiftui/UI/ContentView.swift # examples/run/run.cpp # ggml/CMakeLists.txt # ggml/src/ggml-cpu/CMakeLists.txt # ggml/src/ggml-cuda/CMakeLists.txt # ggml/src/ggml-musa/CMakeLists.txt
2 parents ccd2dbe + 335eb04 commit 159c47f

File tree

20 files changed

+985
-30
lines changed

20 files changed

+985
-30
lines changed

CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ if (LLAMA_CUBLAS)
124124
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cuda_driver)
125125

126126
if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
127-
# 52 == lowest CUDA 12 standard
127+
# 50 == lowest CUDA 12 standard
128128
# 60 == f16 CUDA intrinsics
129129
# 61 == integer CUDA intrinsics
130130
# 70 == (assumed) compute capability at which unrolling a loop in mul_mat_q kernels is faster
@@ -135,9 +135,9 @@ if (LLAMA_CUBLAS)
135135
message("CUDA Toolkit Version: ${CUDAToolkit_VERSION}")
136136
if(CUDAToolkit_VERSION VERSION_GREATER 12)
137137
add_compile_definitions(GGML_CUDA_USE_GRAPHS) #try enable cuda graphs on cu12 build
138-
set(CMAKE_CUDA_ARCHITECTURES "52;61;70;75") # lowest CUDA 12 standard + lowest for integer intrinsics
138+
set(CMAKE_CUDA_ARCHITECTURES "50;61;70;75") # lowest CUDA 12 standard + lowest for integer intrinsics
139139
else()
140-
set(CMAKE_CUDA_ARCHITECTURES "37;52;61;70;75") # lowest CUDA 12 standard + lowest for integer intrinsics
140+
set(CMAKE_CUDA_ARCHITECTURES "50;61;70;75") # lowest CUDA 12 standard + lowest for integer intrinsics
141141
endif()
142142
endif()
143143
endif()

examples/llava/clip.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2811,9 +2811,13 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
28112811

28122812
if (!ctx->has_glm_projector) {
28132813
struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches");
2814+
// The patches vector is used to get rows to index into the embeds with;
2815+
// we should skip dim 0 only if we have CLS to avoid going out of bounds
2816+
// when retrieving the rows.
2817+
int patch_offset = ctx->has_class_embedding ? 1 : 0;
28142818
int* patches_data = (int*)malloc(ggml_nbytes(patches));
28152819
for (int i = 0; i < num_patches; i++) {
2816-
patches_data[i] = i + 1;
2820+
patches_data[i] = i + patch_offset;
28172821
}
28182822
ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
28192823
free(patches_data);
37 Bytes
Binary file not shown.

examples/server/utils.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
// increase max payload length to allow use of larger context size
99
#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
10+
// disable Nagle's algorithm
11+
#define CPPHTTPLIB_TCP_NODELAY true
1012
#include "httplib.h"
1113

1214
// Change JSON_ASSERT from assert() to GGML_ASSERT:

examples/server/webui/src/components/ChatScreen.tsx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ export default function ChatScreen() {
228228
value={inputMsg}
229229
onChange={(e) => setInputMsg(e.target.value)}
230230
onKeyDown={(e) => {
231+
if (e.nativeEvent.isComposing || e.keyCode === 229) return;
231232
if (e.key === 'Enter' && e.shiftKey) return;
232233
if (e.key === 'Enter' && !e.shiftKey) {
233234
e.preventDefault();

examples/server/webui/src/utils/llama-vscode.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ export const useVSCodeContext = (
4040

4141
window.addEventListener('message', handleMessage);
4242
return () => window.removeEventListener('message', handleMessage);
43-
}, []);
43+
}, [inputRef, setInputMsg]);
4444

4545
// Add a keydown listener that sends the "escapePressed" message to the parent window
4646
useEffect(() => {

ggml/include/ggml-cpu.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ extern "C" {
9595
GGML_BACKEND_API int ggml_cpu_has_matmul_int8(void);
9696
GGML_BACKEND_API int ggml_cpu_has_sve (void);
9797
GGML_BACKEND_API int ggml_cpu_get_sve_cnt (void); // sve vector length in bytes
98+
GGML_BACKEND_API int ggml_cpu_has_sme (void);
9899
// other
99100
GGML_BACKEND_API int ggml_cpu_has_riscv_v (void);
100101
GGML_BACKEND_API int ggml_cpu_has_vsx (void);

ggml/src/ggml-cpu/ggml-cpu-quants.c

Lines changed: 176 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5113,7 +5113,182 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
51135113

51145114
const int nb = n / QK_K;
51155115

5116-
#ifdef __ARM_NEON
5116+
#if defined(__ARM_FEATURE_SVE)
5117+
5118+
uint32_t utmp[4];
5119+
5120+
const int8_t m32 = 32;
5121+
const int vector_length = svcntb()*8;
5122+
const svuint8_t m3b_sv = svdup_n_u8(0x3);
5123+
const svint32_t vzero_sv = svdup_n_s32(0);
5124+
5125+
const svuint8_t m0_sv = svdup_n_u8(1);
5126+
const svuint8_t m1_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 1);
5127+
const svuint8_t m2_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 2);
5128+
const svuint8_t m3_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 3);
5129+
svbool_t pred_s32 = svnot_b_z (svptrue_b32(), svptrue_pat_b32(SV_VL4));
5130+
5131+
float sum = 0;
5132+
5133+
for (int i = 0; i < nb; ++i) {
5134+
5135+
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
5136+
5137+
const uint8_t * restrict q3_sv = x[i].qs;
5138+
const uint8_t * restrict qh_sv = x[i].hmask;
5139+
const int8_t * restrict q8_sv = y[i].qs;
5140+
5141+
// Set up scales
5142+
uint32_t * aux = &x[i].scales;
5143+
utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
5144+
utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
5145+
utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
5146+
utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
5147+
5148+
int8_t * scale = (int8_t *)utmp;
5149+
5150+
for (int j = 0; j < 16; ++j) scale[j] -= m32;
5151+
5152+
switch (vector_length) {
5153+
case 128:
5154+
{
5155+
svuint8_t qhbits_sv_1 = svld1_u8(svptrue_b8(), qh_sv);
5156+
svuint8_t qhbits_sv_2 = svld1_u8(svptrue_b8(), qh_sv+16);
5157+
svuint8_t q3h_sv;
5158+
5159+
svint32_t sumi1_1 = svdup_n_s32(0);
5160+
svint8_t q3bytes_sv;
5161+
5162+
for (int j = 0; j < QK_K/128; ++j) {
5163+
5164+
const svuint8_t q3bits_sv = svld1_u8(svptrue_b8(), q3_sv); q3_sv += 16;
5165+
const svuint8_t q3bits_sv_1 = svld1_u8(svptrue_b8(), q3_sv); q3_sv += 16;
5166+
svint8_t q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
5167+
svint8_t q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
5168+
5169+
q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m0_sv, qhbits_sv_1), 2);
5170+
q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q3bits_sv, m3b_sv)), svreinterpret_s8_u8(q3h_sv));
5171+
5172+
sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[0]));
5173+
5174+
q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m0_sv, qhbits_sv_2), 2);
5175+
q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q3bits_sv_1, m3b_sv)), svreinterpret_s8_u8(q3h_sv));
5176+
5177+
sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[1]));
5178+
5179+
q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
5180+
q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
5181+
5182+
q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m1_sv, qhbits_sv_1), 1);
5183+
q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
5184+
5185+
sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[2]));
5186+
5187+
q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m1_sv, qhbits_sv_2), 1);
5188+
q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
5189+
5190+
sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[3]));
5191+
5192+
5193+
scale += 4;
5194+
q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
5195+
q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
5196+
5197+
q3h_sv = svbic_u8_x(svptrue_b8(), m2_sv, qhbits_sv_1);
5198+
q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
5199+
5200+
sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[0]));
5201+
5202+
q3h_sv = svbic_u8_x(svptrue_b8(), m2_sv, qhbits_sv_2);
5203+
q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
5204+
5205+
sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[1]));
5206+
5207+
5208+
q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
5209+
q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
5210+
5211+
q3h_sv = svlsr_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m3_sv, qhbits_sv_1), 1);
5212+
q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
5213+
5214+
sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[2]));
5215+
5216+
q3h_sv = svlsr_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m3_sv, qhbits_sv_2), 1);
5217+
q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
5218+
5219+
sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[3]));
5220+
5221+
if (j == 0) {
5222+
qhbits_sv_1 = svlsr_n_u8_x(svptrue_b8(), qhbits_sv_1, 4);
5223+
qhbits_sv_2 = svlsr_n_u8_x(svptrue_b8(), qhbits_sv_2, 4);
5224+
}
5225+
5226+
scale += 4;
5227+
}
5228+
5229+
sum += d * (svaddv_s32(svptrue_b32(), sumi1_1));
5230+
} break;
5231+
case 256:
5232+
case 512:
5233+
{
5234+
svuint8_t qhbits_sv = svld1_u8(svptrue_pat_b8(SV_VL32), qh_sv);
5235+
svuint8_t q3h_sv;
5236+
5237+
svint32_t sumi1_1 = svdup_n_s32(0);
5238+
svint8_t q3bytes_sv;
5239+
5240+
for (int j = 0; j < QK_K/128; ++j) {
5241+
5242+
const svuint8_t q3bits_sv = svld1_u8(svptrue_pat_b8(SV_VL32), q3_sv); q3_sv += 32;
5243+
svint8_t q8bytes_1_sv_1 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
5244+
svint8_t q8bytes_1_sv_2 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
5245+
5246+
q3h_sv = svlsl_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m0_sv, qhbits_sv), 2);
5247+
q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q3bits_sv, m3b_sv)), svreinterpret_s8_u8(q3h_sv));
5248+
5249+
5250+
svint32_t scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[0]), svdup_n_s32((int32_t)scale[1]));
5251+
sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), scale_1);
5252+
5253+
q3h_sv = svlsl_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m1_sv, qhbits_sv), 1);
5254+
q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
5255+
5256+
scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[2]), svdup_n_s32((int32_t)scale[3]));
5257+
sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), scale_1);
5258+
5259+
scale += 4;
5260+
q8bytes_1_sv_1 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
5261+
q8bytes_1_sv_2 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
5262+
5263+
q3h_sv = svbic_u8_x(svptrue_pat_b8(SV_VL32), m2_sv, qhbits_sv);
5264+
q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
5265+
5266+
scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[0]), svdup_n_s32((int32_t)scale[1]));
5267+
sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), scale_1);
5268+
5269+
q3h_sv = svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m3_sv, qhbits_sv), 1);
5270+
q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
5271+
5272+
scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[2]), svdup_n_s32((int32_t)scale[3]));
5273+
sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), scale_1);
5274+
5275+
if (j == 0) {
5276+
qhbits_sv = svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), qhbits_sv, 4);
5277+
}
5278+
5279+
scale += 4;
5280+
}
5281+
5282+
sum += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), sumi1_1));
5283+
} break;
5284+
default:
5285+
assert(false && "Unsupported vector length");
5286+
break;
5287+
}
5288+
}
5289+
*s = sum;
5290+
5291+
#elif __ARM_NEON
51175292

51185293
uint32_t aux[3];
51195294
uint32_t utmp[4];

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,8 @@ struct ggml_arm_arch_features_type {
117117
int has_i8mm;
118118
int has_sve;
119119
int sve_cnt;
120-
} ggml_arm_arch_features = {-1, -1, -1, -1, 0};
120+
int has_sme;
121+
} ggml_arm_arch_features = {-1, -1, -1, -1, 0, -1};
121122
#endif
122123

123124

@@ -2387,15 +2388,20 @@ bool ggml_is_numa(void) {
23872388
#define HWCAP2_I8MM (1 << 13)
23882389
#endif
23892390

2391+
#if !defined(HWCAP2_SME)
2392+
#define HWCAP2_SME (1 << 23)
2393+
#endif
2394+
23902395
static void ggml_init_arm_arch_features(void) {
23912396
#if defined(__linux__) && defined(__aarch64__)
23922397
uint32_t hwcap = getauxval(AT_HWCAP);
23932398
uint32_t hwcap2 = getauxval(AT_HWCAP2);
23942399

2395-
ggml_arm_arch_features.has_neon = !!(hwcap & HWCAP_ASIMD);
2400+
ggml_arm_arch_features.has_neon = !!(hwcap & HWCAP_ASIMD);
23962401
ggml_arm_arch_features.has_dotprod = !!(hwcap & HWCAP_ASIMDDP);
2397-
ggml_arm_arch_features.has_i8mm = !!(hwcap2 & HWCAP2_I8MM);
2398-
ggml_arm_arch_features.has_sve = !!(hwcap & HWCAP_SVE);
2402+
ggml_arm_arch_features.has_i8mm = !!(hwcap2 & HWCAP2_I8MM);
2403+
ggml_arm_arch_features.has_sve = !!(hwcap & HWCAP_SVE);
2404+
ggml_arm_arch_features.has_sme = !!(hwcap2 & HWCAP2_SME);
23992405

24002406
#if defined(__ARM_FEATURE_SVE)
24012407
ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL);
@@ -2418,6 +2424,11 @@ static void ggml_init_arm_arch_features(void) {
24182424
}
24192425
ggml_arm_arch_features.has_i8mm = oldp;
24202426

2427+
if (sysctlbyname("hw.optional.arm.FEAT_SME", &oldp, &size, NULL, 0) != 0) {
2428+
oldp = 0;
2429+
}
2430+
ggml_arm_arch_features.has_sme = oldp;
2431+
24212432
ggml_arm_arch_features.has_sve = 0;
24222433
ggml_arm_arch_features.sve_cnt = 0;
24232434
#else
@@ -2441,6 +2452,12 @@ static void ggml_init_arm_arch_features(void) {
24412452
ggml_arm_arch_features.has_sve = 0;
24422453
ggml_arm_arch_features.sve_cnt = 0;
24432454
#endif
2455+
2456+
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_SME2)
2457+
ggml_arm_arch_features.has_sme = 1;
2458+
#else
2459+
ggml_arm_arch_features.has_sme = 0;
2460+
#endif
24442461
#endif
24452462
}
24462463
#endif
@@ -14487,6 +14504,14 @@ int ggml_cpu_get_sve_cnt(void) {
1448714504
#endif
1448814505
}
1448914506

14507+
int ggml_cpu_has_sme(void) {
14508+
#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_SME)
14509+
return ggml_arm_arch_features.has_sme;
14510+
#else
14511+
return 0;
14512+
#endif
14513+
}
14514+
1449014515
void ggml_cpu_init(void) {
1449114516
// needed to initialize f16 tables
1449214517
{

ggml/src/ggml-cpu/ggml-cpu.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414
#include "ggml-cpu-hbm.h"
1515
#endif
1616

17+
#ifdef GGML_USE_CPU_KLEIDIAI
18+
#include "kleidiai/kleidiai.h"
19+
#endif
20+
1721
#if defined(__APPLE__)
1822
#include <sys/types.h>
1923
#include <sys/sysctl.h>
@@ -39,6 +43,12 @@ std::vector<ggml_backend_buffer_type_t>& ggml_backend_cpu_get_extra_buffers_type
3943
// }
4044
// #endif
4145

46+
#ifdef GGML_USE_CPU_KLEIDIAI
47+
if (ggml_backend_cpu_kleidiai_buffer_type()) {
48+
bufts.push_back(ggml_backend_cpu_kleidiai_buffer_type());
49+
}
50+
#endif
51+
4252
#ifdef GGML_USE_CPU_AARCH64
4353
if (ggml_backend_cpu_aarch64_buffer_type()) {
4454
bufts.push_back(ggml_backend_cpu_aarch64_buffer_type());
@@ -538,6 +548,9 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r
538548
static std::string sve_cnt = std::to_string(ggml_cpu_get_sve_cnt());
539549
features.push_back({ "SVE_CNT", sve_cnt.c_str() });
540550
}
551+
if (ggml_cpu_has_sme()) {
552+
features.push_back({ "SME", "1" });
553+
}
541554
if (ggml_cpu_has_riscv_v()) {
542555
features.push_back({ "RISCV_V", "1" });
543556
}
@@ -559,6 +572,9 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r
559572
#ifdef GGML_USE_OPENMP
560573
features.push_back({ "OPENMP", "1" });
561574
#endif
575+
#ifdef GGML_USE_CPU_KLEIDIAI
576+
features.push_back({ "KLEIDIAI", "1" });
577+
#endif
562578
#ifdef GGML_USE_CPU_AARCH64
563579
features.push_back({ "AARCH64_REPACK", "1" });
564580
#endif

0 commit comments

Comments
 (0)