Skip to content

Commit 3f86ed8

Browse files
committed
kleidiai: add support for get_rows
1 parent 0d92267 commit 3f86ed8

File tree

4 files changed

+202
-24
lines changed

4 files changed

+202
-24
lines changed

ggml/src/ggml-cpu/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -494,9 +494,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
494494

495495
# Fetch KleidiAI sources:
496496
include(FetchContent)
497-
set(KLEIDIAI_COMMIT_TAG "v1.9.0")
497+
set(KLEIDIAI_COMMIT_TAG "v1.11.0")
498498
set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
499-
set(KLEIDIAI_ARCHIVE_MD5 "2a8e1bb55d201557553545536489a017")
499+
set(KLEIDIAI_ARCHIVE_MD5 "3fe9e5ab964c375c53839296eb71eaa2")
500500

501501
if (POLICY CMP0135)
502502
cmake_policy(SET CMP0135 NEW)

ggml/src/ggml-cpu/kleidiai/kernels.cpp

Lines changed: 125 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,96 @@
2525
#include "kernels.h"
2626

2727
#define NELEMS(x) sizeof(x) / sizeof(*x)
28+
29+
static const size_t INT4_PER_BYTE = 2;
30+
static const size_t INT4_BITS = 4;
31+
static const int Q4_0_ZERO_POINT = 8;
32+
const size_t INT4_PER_UINT16 = 4;
33+
34+
/**
 * Widen a ggml half-precision value to a 32-bit float.
 *
 * The raw bits are moved into an ARM __fp16 with memcpy (avoids a
 * strict-aliasing pointer cast) and the compiler performs the native
 * fp16 -> fp32 conversion on the cast to float.
 */
static inline float compute_fp16_to_fp32(ggml_fp16_t h) {
    static_assert(sizeof(ggml_fp16_t) == sizeof(__fp16), "ggml_fp16_t and __fp16 must be the same size");
    __fp16 half_val;
    memcpy(&half_val, &h, sizeof(ggml_fp16_t));
    return (float)half_val;
}
40+
41+
static void dequantize_row_qsi4c32pscalef16(
42+
const void *packed_data,
43+
int32_t row_idx,
44+
int64_t nc,
45+
float *out,
46+
size_t nr_pack,
47+
size_t packed_row_stride,
48+
size_t kr,
49+
size_t bl,
50+
size_t num_bytes_multiplier
51+
) {
52+
size_t group_idx = row_idx / nr_pack;
53+
size_t row_in_group = row_idx % nr_pack;
54+
const uint8_t *packed_group = (const uint8_t *)packed_data + group_idx * packed_row_stride;
55+
size_t num_blocks = nc / bl;
56+
const uint8_t *block_ptr = packed_group;
57+
58+
for (size_t b = 0; b < num_blocks; ++b) {
59+
uint16_t scale_f16 = *((const uint16_t *)(block_ptr + row_in_group * num_bytes_multiplier));
60+
float scale = compute_fp16_to_fp32(scale_f16);
61+
62+
const uint8_t *segment_ptr = block_ptr + nr_pack * num_bytes_multiplier;
63+
size_t num_segments = bl / kr;
64+
size_t num_bytes_per_segment = kr / INT4_PER_BYTE;
65+
66+
for (size_t s = 0; s < num_segments; ++s) {
67+
const uint8_t *seg_base = segment_ptr + s * nr_pack * num_bytes_per_segment;
68+
const uint8_t *qbytes = seg_base + row_in_group * num_bytes_per_segment;
69+
for (size_t k = 0; k < num_bytes_per_segment; ++k) {
70+
uint8_t byte = qbytes[k] ^ 0x88;
71+
int x0 = (byte & 0x0F) - Q4_0_ZERO_POINT;
72+
int x1 = (byte >> INT4_BITS) - Q4_0_ZERO_POINT;
73+
out[b * bl + s * num_bytes_per_segment + k] = x0 * scale;
74+
out[b * bl + s * num_bytes_per_segment + k + bl/2] = x1 * scale;
75+
}
76+
}
77+
block_ptr += nr_pack * num_bytes_multiplier + num_segments * nr_pack * num_bytes_per_segment;
78+
}
79+
}
80+
81+
static void dequantize_row_qsi4c32ps1s0scalef16(
82+
const void *packed_data,
83+
int32_t row_idx,
84+
int64_t k,
85+
float *out,
86+
size_t nr,
87+
size_t packed_row_stride,
88+
size_t kr,
89+
size_t bl,
90+
size_t num_bytes_multiplier
91+
) {
92+
const size_t num_blocks = k / bl;
93+
const size_t bl4 = bl / INT4_PER_UINT16;
94+
95+
size_t group_idx = row_idx / nr;
96+
size_t row_in_group = row_idx % nr;
97+
98+
const uint8_t *packed_group = (const uint8_t *)packed_data + group_idx * packed_row_stride;
99+
const uint16_t *qdata = (const uint16_t *)packed_group;
100+
const uint16_t *scales = (const uint16_t *)(packed_group + packed_row_stride - (nr * num_blocks * num_bytes_multiplier));
101+
102+
for (size_t block_idx = 0; block_idx < num_blocks; ++block_idx) {
103+
uint16_t scale_f16 = scales[row_in_group + block_idx * nr];
104+
float scale = compute_fp16_to_fp32(scale_f16);
105+
106+
for (size_t bl4_idx = 0; bl4_idx < bl4; ++bl4_idx) {
107+
uint16_t q = qdata[(block_idx * bl4 + bl4_idx) * nr + row_in_group];
108+
109+
for (size_t qidx = 0; qidx < INT4_PER_UINT16; ++qidx) {
110+
int v = ((q >> (qidx * 4)) & 0xF) - Q4_0_ZERO_POINT;
111+
out[block_idx * bl + bl4_idx * INT4_BITS + qidx] = v * scale;
112+
}
113+
}
114+
}
115+
GGML_UNUSED(kr);
116+
}
117+
28118
static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
29119
#if defined(__ARM_FEATURE_SME)
30120
{
@@ -63,8 +153,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
63153
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
64154
},
65155
/* .rhs_info = */ {
66-
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
67-
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
156+
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
157+
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
158+
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
159+
/* .to_float = */ dequantize_row_qsi4c32ps1s0scalef16,
68160
},
69161
/* .required_cpu = */ CPU_FEATURE_SME,
70162
/* .lhs_type = */ GGML_TYPE_F32,
@@ -107,8 +199,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
107199
/* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme,
108200
},
109201
/* .rhs_info = */ {
110-
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
111-
/* .pack_func = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
202+
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
203+
/* .packed_stride = */ NULL,
204+
/* .pack_func = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
205+
/* .to_float = */ NULL,
112206
},
113207
/* .required_cpu = */ CPU_FEATURE_SME,
114208
/* .lhs_type = */ GGML_TYPE_F32,
@@ -154,8 +248,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
154248
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
155249
},
156250
/* .rhs_info = */ {
157-
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
158-
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
251+
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
252+
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
253+
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
254+
/* .to_float = */ dequantize_row_qsi4c32pscalef16,
159255
},
160256
/* .required_cpu = */ CPU_FEATURE_DOTPROD,
161257
/* .lhs_type = */ GGML_TYPE_F32,
@@ -200,8 +296,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
200296
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
201297
},
202298
/* .rhs_info = */ {
203-
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
204-
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
299+
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
300+
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
301+
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
302+
/* .to_float = */ dequantize_row_qsi4c32pscalef16,
205303
},
206304
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
207305
/* .lhs_type = */ GGML_TYPE_F32,
@@ -247,8 +345,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
247345
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
248346
},
249347
/* .rhs_info = */ {
250-
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
251-
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
348+
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
349+
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
350+
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
351+
/* .to_float = */ dequantize_row_qsi4c32pscalef16,
252352
},
253353
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
254354
/* .lhs_type = */ GGML_TYPE_F32,
@@ -293,8 +393,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
293393
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
294394
},
295395
/* .rhs_info = */ {
296-
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
297-
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
396+
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
397+
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
398+
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
399+
/* .to_float = */ dequantize_row_qsi4c32pscalef16,
298400
},
299401
/* .required_cpu = */ CPU_FEATURE_DOTPROD,
300402
/* .lhs_type = */ GGML_TYPE_F32,
@@ -305,6 +407,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
305407
#endif
306408
};
307409

410+
/**
 * Human-readable name of a single cpu_feature flag, for debug logging.
 * A value that is not exactly one of the known flags (e.g. an OR-ed
 * combination) yields "UNKNOWN".
 */
const char* cpu_feature_to_string(cpu_feature f) {
    if (f == CPU_FEATURE_NONE)    return "NONE";
    if (f == CPU_FEATURE_DOTPROD) return "DOTPROD";
    if (f == CPU_FEATURE_I8MM)    return "I8MM";
    if (f == CPU_FEATURE_SVE)     return "SVE";
    if (f == CPU_FEATURE_SME)     return "SME";
    return "UNKNOWN";
}
420+
308421
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor) {
309422
ggml_kleidiai_kernels * kernel = nullptr;
310423

ggml/src/ggml-cpu/kleidiai/kernels.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,15 @@ struct rhs_packing_info {
7171
std::function<size_t(size_t n, size_t k, size_t nr, size_t kr, size_t bl)>,
7272
std::function<size_t(size_t n, size_t k)>
7373
> packed_size;
74+
size_t (*packed_stride)(size_t k, size_t nr, size_t kr, size_t bl);
7475
std::variant<
7576
std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
7677
const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params)>,
7778
std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs,
7879
const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params)>
7980
> pack_func;
81+
void (*to_float)(const void *packed_data, int32_t row_idx, int64_t nc, float *out, size_t nr_pack, size_t packed_row_stride,
82+
size_t kr, size_t bl, size_t num_bytes_multiplier);
8083
};
8184

8285
struct ggml_kleidiai_kernels {
@@ -93,3 +96,4 @@ struct ggml_kleidiai_kernels {
9396

9497
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor);
9598
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features);
99+
const char* cpu_feature_to_string(cpu_feature features);

ggml/src/ggml-cpu/kleidiai/kleidiai.cpp

Lines changed: 71 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@ static void init_kleidiai_context(void) {
6262
ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
6363
}
6464
ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features);
65+
#ifndef NDEBUG
66+
if (ctx.kernels) {
67+
GGML_LOG_DEBUG("kleidiai: using kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels->required_cpu));
68+
}
69+
#endif
6570
}
6671
ggml_critical_section_end();
6772
}
@@ -102,6 +107,9 @@ static void transpose_f32kxn_f16nxk(size_t n, size_t k, float * dst, const uint1
102107

103108
class tensor_traits : public ggml::cpu::tensor_traits {
104109
bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
110+
if (op->op != GGML_OP_MUL_MAT) {
111+
return false;
112+
}
105113
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
106114
GGML_ASSERT(kernels);
107115
kernel_info * kernel = op->src[1]->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
@@ -135,6 +143,10 @@ class tensor_traits : public ggml::cpu::tensor_traits {
135143
} else if (dst->src[0]->type == GGML_TYPE_F16) {
136144
return compute_forward_kv_cache(params, dst);
137145
}
146+
} else if (dst->op == GGML_OP_GET_ROWS) {
147+
if (dst->src[0]->type == GGML_TYPE_Q4_0) {
148+
return compute_forward_get_rows(params, dst);
149+
}
138150
}
139151
return false;
140152
}
@@ -342,6 +354,45 @@ class tensor_traits : public ggml::cpu::tensor_traits {
342354
return true;
343355
}
344356

357+
bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
358+
GGML_ASSERT(ctx.kernels);
359+
360+
const ggml_tensor * src0 = dst->src[0];
361+
const ggml_tensor * src1 = dst->src[1];
362+
363+
GGML_TENSOR_BINARY_OP_LOCALS
364+
365+
rhs_packing_info * rhs_info = &ctx.kernels->rhs_info;
366+
kernel_info * kernel = &ctx.kernels->gemm;
367+
368+
const int64_t nc = ne00;
369+
const int64_t nr = ggml_nelements(src1);
370+
371+
const size_t block_rows = kernel->get_nr();
372+
const size_t kr = kernel->get_kr();
373+
374+
const size_t num_bytes_multiplier = sizeof(uint16_t);
375+
const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, QK4_0);
376+
377+
const int ith = params->ith;
378+
const int nth = params->nth;
379+
380+
const int dr = (nr + nth - 1) / nth;
381+
const int ir0 = dr * ith;
382+
const int ir1 = MIN(ir0 + dr, nr);
383+
384+
for (int64_t i = ir0; i < ir1; ++i) {
385+
int32_t row_idx = ((const int32_t *)src1->data)[i];
386+
GGML_ASSERT(row_idx >= 0 && row_idx < (int32_t)src0->ne[1]);
387+
388+
float *out = (float *)((char *)dst->data + i * nb1);
389+
390+
rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, QK4_0, num_bytes_multiplier);
391+
}
392+
393+
return true;
394+
}
395+
345396
public:
346397
int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
347398
GGML_ASSERT(ctx.kernels);
@@ -351,17 +402,12 @@ class tensor_traits : public ggml::cpu::tensor_traits {
351402
size_t kr = ctx.kernels->gemm.get_kr();
352403
size_t sr = ctx.kernels->gemm.get_sr();
353404

354-
#ifndef NDEBUG
355-
const size_t repacked_size = variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
356-
GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!");
357-
#endif
358405
struct kai_rhs_pack_qs4cxs1s0_param params;
359406
params.lhs_zero_point = 1;
360407
params.rhs_zero_point = 8;
361408
variant_call<void>(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const uint8_t*)data, nullptr, tensor->data, 0, &params);
362409

363410
return 0;
364-
365411
GGML_UNUSED(data_size);
366412
}
367413
};
@@ -375,8 +421,8 @@ static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struc
375421
static enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
376422
tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);
377423

378-
GGML_UNUSED(buffer);
379424
return GGML_STATUS_SUCCESS;
425+
GGML_UNUSED(buffer);
380426
}
381427

382428
static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
@@ -418,18 +464,33 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_b
418464
GGML_UNUSED(buft);
419465
}
420466

467+
/**
 * Buffer-type hook: number of bytes needed to hold `tensor` once it has
 * been repacked into the KleidiAI RHS layout (this may differ from
 * ggml_nbytes, which is why the default alloc-size cannot be used).
 */
static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
    GGML_ASSERT(ctx.kernels);

    const size_t n  = tensor->ne[1]; // rows
    const size_t k  = tensor->ne[0]; // columns
    const size_t nr = ctx.kernels->gemm.get_nr();
    const size_t kr = ctx.kernels->gemm.get_kr();

    return variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);

    GGML_UNUSED(buft);
}
478+
421479
namespace ggml::cpu::kleidiai {
422480
class extra_buffer_type : ggml::cpu::extra_buffer_type {
423481
bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
424-
if (op->op == GGML_OP_MUL_MAT &&
482+
if ((op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) &&
425483
op->src[0]->type == GGML_TYPE_Q4_0 &&
426484
op->src[0]->buffer &&
427485
(ggml_n_dims(op->src[0]) == 2) &&
428486
op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) {
487+
if (op->op == GGML_OP_GET_ROWS && op->src[1]->ne[0] != 8) {
488+
return false;
489+
}
429490
if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
430491
return false;
431492
}
432-
if (op->src[1]->type == GGML_TYPE_F32 &&
493+
if ((op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_I32) &&
433494
ggml_ne(op->src[1], 2) == 1 && ggml_ne(op->src[1], 3) == 1) {
434495
return true;
435496
}
@@ -438,7 +499,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
438499
}
439500

440501
ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
441-
if (op->op == GGML_OP_MUL_MAT) {
502+
if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) {
442503
if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
443504
return (ggml::cpu::tensor_traits *) op->src[0]->extra;
444505
}
@@ -469,7 +530,7 @@ ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void) {
469530
/* .alloc_buffer = */ ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer,
470531
/* .get_alignment = */ ggml_backend_cpu_kleidiai_buffer_type_get_alignment,
471532
/* .get_max_size = */ nullptr, // defaults to SIZE_MAX
472-
/* .get_alloc_size = */ nullptr, // defaults to ggml_nbytes
533+
/* .get_alloc_size = */ ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size,
473534
/* .is_host = */ nullptr,
474535
},
475536
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),

0 commit comments

Comments
 (0)