|
22 | 22 |
|
23 | 23 | #include "kai_common.h" |
24 | 24 |
|
| 25 | +#include "simd-mappings.h" |
| 26 | + |
25 | 27 | #include "kernels.h" |
26 | 28 |
|
/* Element count of a true array. Fully parenthesized so it composes safely
 * inside larger expressions; do NOT use on pointers (they do not carry size). */
#define NELEMS(x) (sizeof(x) / sizeof(*(x)))

/* Layout constants shared by the 4-bit packed-format dequantizers below. */
static const size_t INT4_PER_BYTE   = 2;  /* two 4-bit nibbles per byte */
static const size_t INT4_BITS       = 4;  /* bit width of one quantized value */
static const int    Q4_0_ZERO_POINT = 8;  /* q4_0 values are stored with a +8 offset */
/* static added: was the only constant here with external linkage, which is
 * inconsistent with its siblings and could collide at link time. */
static const size_t INT4_PER_UINT16 = 4;

| 36 | +static void dequantize_row_qsi4c32pscalef16( |
| 37 | + const void *packed_data, |
| 38 | + int32_t row_idx, |
| 39 | + int64_t nc, |
| 40 | + float *out, |
| 41 | + size_t nr_pack, |
| 42 | + size_t packed_row_stride, |
| 43 | + size_t kr, |
| 44 | + size_t bl, |
| 45 | + size_t num_bytes_multiplier |
| 46 | +) { |
| 47 | + size_t group_idx = row_idx / nr_pack; |
| 48 | + size_t row_in_group = row_idx % nr_pack; |
| 49 | + const uint8_t *packed_group = (const uint8_t *)packed_data + group_idx * packed_row_stride; |
| 50 | + size_t num_blocks = nc / bl; |
| 51 | + const uint8_t *block_ptr = packed_group; |
| 52 | + |
| 53 | + for (size_t b = 0; b < num_blocks; ++b) { |
| 54 | + uint16_t scale_f16 = *((const uint16_t *)(block_ptr + row_in_group * num_bytes_multiplier)); |
| 55 | + float scale = GGML_CPU_FP16_TO_FP32(scale_f16); |
| 56 | + |
| 57 | + const uint8_t *segment_ptr = block_ptr + nr_pack * num_bytes_multiplier; |
| 58 | + size_t num_segments = bl / kr; |
| 59 | + size_t num_bytes_per_segment = kr / INT4_PER_BYTE; |
| 60 | + |
| 61 | + for (size_t s = 0; s < num_segments; ++s) { |
| 62 | + const uint8_t *seg_base = segment_ptr + s * nr_pack * num_bytes_per_segment; |
| 63 | + const uint8_t *qbytes = seg_base + row_in_group * num_bytes_per_segment; |
| 64 | + for (size_t k = 0; k < num_bytes_per_segment; ++k) { |
| 65 | + uint8_t byte = qbytes[k] ^ 0x88; |
| 66 | + int x0 = (byte & 0x0F) - Q4_0_ZERO_POINT; |
| 67 | + int x1 = (byte >> INT4_BITS) - Q4_0_ZERO_POINT; |
| 68 | + out[b * bl + s * num_bytes_per_segment + k] = x0 * scale; |
| 69 | + out[b * bl + s * num_bytes_per_segment + k + bl/2] = x1 * scale; |
| 70 | + } |
| 71 | + } |
| 72 | + block_ptr += nr_pack * num_bytes_multiplier + num_segments * nr_pack * num_bytes_per_segment; |
| 73 | + } |
| 74 | +} |
| 75 | + |
| 76 | +static void dequantize_row_qsi4c32ps1s0scalef16( |
| 77 | + const void *packed_data, |
| 78 | + int32_t row_idx, |
| 79 | + int64_t k, |
| 80 | + float *out, |
| 81 | + size_t nr, |
| 82 | + size_t packed_row_stride, |
| 83 | + size_t kr, |
| 84 | + size_t bl, |
| 85 | + size_t num_bytes_multiplier |
| 86 | +) { |
| 87 | + const size_t num_blocks = k / bl; |
| 88 | + const size_t bl4 = bl / INT4_PER_UINT16; |
| 89 | + |
| 90 | + size_t group_idx = row_idx / nr; |
| 91 | + size_t row_in_group = row_idx % nr; |
| 92 | + |
| 93 | + const uint8_t *packed_group = (const uint8_t *)packed_data + group_idx * packed_row_stride; |
| 94 | + const uint16_t *qdata = (const uint16_t *)packed_group; |
| 95 | + const uint16_t *scales = (const uint16_t *)(packed_group + packed_row_stride - (nr * num_blocks * num_bytes_multiplier)); |
| 96 | + |
| 97 | + for (size_t block_idx = 0; block_idx < num_blocks; ++block_idx) { |
| 98 | + uint16_t scale_f16 = scales[row_in_group + block_idx * nr]; |
| 99 | + float scale = GGML_CPU_FP16_TO_FP32(scale_f16); |
| 100 | + |
| 101 | + for (size_t bl4_idx = 0; bl4_idx < bl4; ++bl4_idx) { |
| 102 | + uint16_t q = qdata[(block_idx * bl4 + bl4_idx) * nr + row_in_group]; |
| 103 | + |
| 104 | + for (size_t qidx = 0; qidx < INT4_PER_UINT16; ++qidx) { |
| 105 | + int v = ((q >> (qidx * 4)) & 0xF) - Q4_0_ZERO_POINT; |
| 106 | + out[block_idx * bl + bl4_idx * INT4_BITS + qidx] = v * scale; |
| 107 | + } |
| 108 | + } |
| 109 | + } |
| 110 | + GGML_UNUSED(kr); |
| 111 | +} |
| 112 | + |
28 | 113 | static ggml_kleidiai_kernels gemm_gemv_kernels[] = { |
29 | 114 | #if defined(__ARM_FEATURE_SME) |
30 | 115 | { |
@@ -63,8 +148,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { |
63 | 148 | /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon, |
64 | 149 | }, |
65 | 150 | /* .rhs_info = */ { |
66 | | - /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, |
67 | | - /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, |
| 151 | + /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, |
| 152 | + /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, |
| 153 | + /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, |
| 154 | + /* .to_float = */ dequantize_row_qsi4c32ps1s0scalef16, |
68 | 155 | }, |
69 | 156 | /* .required_cpu = */ CPU_FEATURE_SME, |
70 | 157 | /* .lhs_type = */ GGML_TYPE_F32, |
@@ -107,8 +194,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { |
107 | 194 | /* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme, |
108 | 195 | }, |
109 | 196 | /* .rhs_info = */ { |
110 | | - /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme, |
111 | | - /* .pack_func = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme, |
| 197 | + /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme, |
| 198 | + /* .packed_stride = */ NULL, |
| 199 | + /* .pack_func = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme, |
| 200 | + /* .to_float = */ NULL, |
112 | 201 | }, |
113 | 202 | /* .required_cpu = */ CPU_FEATURE_SME, |
114 | 203 | /* .lhs_type = */ GGML_TYPE_F32, |
@@ -154,8 +243,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { |
154 | 243 | /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, |
155 | 244 | }, |
156 | 245 | /* .rhs_info = */ { |
157 | | - /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, |
158 | | - /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, |
| 246 | + /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, |
| 247 | + /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, |
| 248 | + /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, |
| 249 | + /* .to_float = */ dequantize_row_qsi4c32pscalef16, |
159 | 250 | }, |
160 | 251 | /* .required_cpu = */ CPU_FEATURE_DOTPROD, |
161 | 252 | /* .lhs_type = */ GGML_TYPE_F32, |
@@ -200,8 +291,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { |
200 | 291 | /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, |
201 | 292 | }, |
202 | 293 | /* .rhs_info = */ { |
203 | | - /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, |
204 | | - /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, |
| 294 | + /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, |
| 295 | + /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, |
| 296 | + /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, |
| 297 | + /* .to_float = */ dequantize_row_qsi4c32pscalef16, |
205 | 298 | }, |
206 | 299 | /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM, |
207 | 300 | /* .lhs_type = */ GGML_TYPE_F32, |
@@ -247,8 +340,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { |
247 | 340 | /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, |
248 | 341 | }, |
249 | 342 | /* .rhs_info = */ { |
250 | | - /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, |
251 | | - /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, |
| 343 | + /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, |
| 344 | + /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, |
| 345 | + /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, |
| 346 | + /* .to_float = */ dequantize_row_qsi4c32pscalef16, |
252 | 347 | }, |
253 | 348 | /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM, |
254 | 349 | /* .lhs_type = */ GGML_TYPE_F32, |
@@ -293,8 +388,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { |
293 | 388 | /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, |
294 | 389 | }, |
295 | 390 | /* .rhs_info = */ { |
296 | | - /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, |
297 | | - /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, |
| 391 | + /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, |
| 392 | + /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, |
| 393 | + /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, |
| 394 | + /* .to_float = */ dequantize_row_qsi4c32pscalef16, |
298 | 395 | }, |
299 | 396 | /* .required_cpu = */ CPU_FEATURE_DOTPROD, |
300 | 397 | /* .lhs_type = */ GGML_TYPE_F32, |
|
0 commit comments