 #include "kernels.h"

 #define NELEMS(x) sizeof(x) / sizeof(*x)
+
+static const size_t INT4_PER_BYTE = 2;
+static const size_t INT4_BITS = 4;
+static const int Q4_0_ZERO_POINT = 8;
+const size_t INT4_PER_UINT16 = 4;
+
+static inline float compute_fp16_to_fp32(ggml_fp16_t h) {
+    static_assert(sizeof(ggml_fp16_t) == sizeof(__fp16), "ggml_fp16_t and __fp16 must be the same size");
+    __fp16 tmp;
+    memcpy(&tmp, &h, sizeof(ggml_fp16_t));
+    return (float)tmp;
+}
+
+static void dequantize_row_qsi4c32pscalef16(
+    const void *packed_data,
+    int32_t row_idx,
+    int64_t nc,
+    float *out,
+    size_t nr_pack,
+    size_t packed_row_stride,
+    size_t kr,
+    size_t bl,
+    size_t num_bytes_multiplier
+) {
+    // Locate the packed group that holds this row and the row's position inside it.
+    size_t group_idx = row_idx / nr_pack;
+    size_t row_in_group = row_idx % nr_pack;
+    const uint8_t *packed_group = (const uint8_t *)packed_data + group_idx * packed_row_stride;
+    size_t num_blocks = nc / bl;
+    const uint8_t *block_ptr = packed_group;
+
+    for (size_t b = 0; b < num_blocks; ++b) {
+        // Each block begins with nr_pack fp16 scales, one per row in the group.
+        uint16_t scale_f16 = *((const uint16_t *)(block_ptr + row_in_group * num_bytes_multiplier));
+        float scale = compute_fp16_to_fp32(scale_f16);
+
+        const uint8_t *segment_ptr = block_ptr + nr_pack * num_bytes_multiplier;
+        size_t num_segments = bl / kr;
+        size_t num_bytes_per_segment = kr / INT4_PER_BYTE;
+
+        for (size_t s = 0; s < num_segments; ++s) {
+            const uint8_t *seg_base = segment_ptr + s * nr_pack * num_bytes_per_segment;
+            const uint8_t *qbytes = seg_base + row_in_group * num_bytes_per_segment;
+            for (size_t k = 0; k < num_bytes_per_segment; ++k) {
+                // Each byte packs two int4 values; low nibbles fill the first half of
+                // the block and high nibbles the second half.
+                uint8_t byte = qbytes[k] ^ 0x88;
+                int x0 = (byte & 0x0F) - Q4_0_ZERO_POINT;
+                int x1 = (byte >> INT4_BITS) - Q4_0_ZERO_POINT;
+                out[b * bl + s * num_bytes_per_segment + k] = x0 * scale;
+                out[b * bl + s * num_bytes_per_segment + k + bl/2] = x1 * scale;
+            }
+        }
+        block_ptr += nr_pack * num_bytes_multiplier + num_segments * nr_pack * num_bytes_per_segment;
+    }
+}
+
+static void dequantize_row_qsi4c32ps1s0scalef16(
+    const void *packed_data,
+    int32_t row_idx,
+    int64_t k,
+    float *out,
+    size_t nr,
+    size_t packed_row_stride,
+    size_t kr,
+    size_t bl,
+    size_t num_bytes_multiplier
+) {
+    const size_t num_blocks = k / bl;
+    const size_t bl4 = bl / INT4_PER_UINT16;
+
+    size_t group_idx = row_idx / nr;
+    size_t row_in_group = row_idx % nr;
+
+    const uint8_t *packed_group = (const uint8_t *)packed_data + group_idx * packed_row_stride;
+    const uint16_t *qdata = (const uint16_t *)packed_group;
+    // The fp16 scales sit at the end of the packed group, nr per block.
+    const uint16_t *scales = (const uint16_t *)(packed_group + packed_row_stride - (nr * num_blocks * num_bytes_multiplier));
+
+    for (size_t block_idx = 0; block_idx < num_blocks; ++block_idx) {
+        uint16_t scale_f16 = scales[row_in_group + block_idx * nr];
+        float scale = compute_fp16_to_fp32(scale_f16);
+
+        for (size_t bl4_idx = 0; bl4_idx < bl4; ++bl4_idx) {
+            uint16_t q = qdata[(block_idx * bl4 + bl4_idx) * nr + row_in_group];
+
+            // Each uint16 packs four int4 values.
+            for (size_t qidx = 0; qidx < INT4_PER_UINT16; ++qidx) {
+                int v = ((q >> (qidx * 4)) & 0xF) - Q4_0_ZERO_POINT;
+                out[block_idx * bl + bl4_idx * INT4_BITS + qidx] = v * scale;
+            }
+        }
+    }
+    GGML_UNUSED(kr);
+}
+
 static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
 #if defined(__ARM_FEATURE_SME)
     {
@@ -63,8 +153,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
         },
         /* .rhs_info = */ {
-            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
-            /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
+            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
+            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
+            /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
+            /* .to_float = */ dequantize_row_qsi4c32ps1s0scalef16,
         },
         /* .required_cpu = */ CPU_FEATURE_SME,
         /* .lhs_type = */ GGML_TYPE_F32,
@@ -107,8 +199,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme,
         },
         /* .rhs_info = */ {
-            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
-            /* .pack_func = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
+            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
+            /* .packed_stride = */ NULL,
+            /* .pack_func = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
+            /* .to_float = */ NULL,
         },
         /* .required_cpu = */ CPU_FEATURE_SME,
         /* .lhs_type = */ GGML_TYPE_F32,
@@ -154,8 +248,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
         },
         /* .rhs_info = */ {
-            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .to_float = */ dequantize_row_qsi4c32pscalef16,
         },
         /* .required_cpu = */ CPU_FEATURE_DOTPROD,
         /* .lhs_type = */ GGML_TYPE_F32,
@@ -200,8 +296,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
         },
         /* .rhs_info = */ {
-            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .to_float = */ dequantize_row_qsi4c32pscalef16,
         },
         /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
         /* .lhs_type = */ GGML_TYPE_F32,
@@ -247,8 +345,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
         },
         /* .rhs_info = */ {
-            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .to_float = */ dequantize_row_qsi4c32pscalef16,
         },
         /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
         /* .lhs_type = */ GGML_TYPE_F32,
@@ -293,8 +393,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
         },
         /* .rhs_info = */ {
-            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .to_float = */ dequantize_row_qsi4c32pscalef16,
         },
         /* .required_cpu = */ CPU_FEATURE_DOTPROD,
         /* .lhs_type = */ GGML_TYPE_F32,
@@ -305,6 +407,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
 #endif
 };

+const char * cpu_feature_to_string(cpu_feature f) {
+    switch (f) {
+        case CPU_FEATURE_NONE:    return "NONE";
+        case CPU_FEATURE_DOTPROD: return "DOTPROD";
+        case CPU_FEATURE_I8MM:    return "I8MM";
+        case CPU_FEATURE_SVE:     return "SVE";
+        case CPU_FEATURE_SME:     return "SME";
+        default:                  return "UNKNOWN";
+    }
+}
+
 ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor) {
     ggml_kleidiai_kernels * kernel = nullptr;

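
A minimal sketch (not from this change) of how a caller might use the new rhs_info.packed_stride and rhs_info.to_float hooks to recover float values from a packed RHS tensor. The field and callback signatures are assumptions inferred from the initializers and dequantize functions above; the actual struct layout lives in kernels.h, which is not shown here, so treat the helper below as illustrative only.

    // Illustrative sketch only. Assumes rhs_info.packed_stride takes (k, nr, kr, bl)
    // like the kai_get_rhs_packed_stride_* helpers, and that rhs_info.to_float
    // matches the dequantize_row_* signature above.
    static void kleidiai_row_to_float(const ggml_kleidiai_kernels * kernels,
                                      const void * rhs_packed,   // output of rhs_info.pack_func
                                      int32_t row_idx, int64_t k, float * dst,
                                      size_t nr, size_t kr, size_t bl) {
        if (!kernels->rhs_info.to_float) {
            return;  // e.g. the bf16 SME entry, which leaves to_float as NULL
        }
        const size_t packed_stride = kernels->rhs_info.packed_stride(k, nr, kr, bl);
        kernels->rhs_info.to_float(rhs_packed, row_idx, k, dst, nr, packed_stride,
                                   kr, bl, sizeof(uint16_t) /* bytes per fp16 scale */);
    }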