
#define UNUSED GGML_UNUSED

+static inline void decode_q4_Kx8_scales_mins(const uint8_t * scales_in,
+                                             int16x8_t *     out_mins,
+                                             int8_t *        out_scales) {
+    constexpr uint32_t kmask1 = 0x3f3f3f3f;
+    constexpr uint32_t kmask2 = 0x0f0f0f0f;
+    constexpr uint32_t kmask3 = 0x03030303;
+    constexpr uint8_t  scales_size = 12;
+
+    uint32_t sm[3];
+    memcpy(sm, scales_in, scales_size);
+
+    const uint32_t   mins_0_3 = sm[1] & kmask1;
+    const uint32_t   mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4);
+    const uint32x2_t mins_u32 = { mins_0_3, mins_4_7 };
+
+    *out_mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins_u32)));
+
+    uint32_t scales_u32[2];
+    scales_u32[0] = sm[0] & kmask1;
+    scales_u32[1] = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4);
+    memcpy(out_scales, scales_u32, 8);
+}
+
+
void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
    assert(QK8_0 == 32);
    assert(k % QK8_0 == 0);
@@ -1890,29 +1914,6 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
    ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
}

-static inline void decode_q4_Kx8_scales_mins(const uint8_t * scales_in,
-                                             int16x8_t *     out_mins,
-                                             int8_t *        out_scales) {
-    constexpr uint32_t kmask1 = 0x3f3f3f3f;
-    constexpr uint32_t kmask2 = 0x0f0f0f0f;
-    constexpr uint32_t kmask3 = 0x03030303;
-    constexpr uint8_t  scales_size = 12;
-
-    uint32_t sm[3];
-    memcpy(sm, scales_in, scales_size);
-
-    const uint32_t   mins_0_3 = sm[1] & kmask1;
-    const uint32_t   mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4);
-    const uint32x2_t mins_u32 = { mins_0_3, mins_4_7 };
-
-    *out_mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins_u32)));
-
-    uint32_t scales_u32[2];
-    scales_u32[0] = sm[0] & kmask1;
-    scales_u32[1] = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4);
-    memcpy(out_scales, scales_u32, 8);
-}
-

void ggml_gemm_q4_K_8x8_q8_K(int n,
                             float * GGML_RESTRICT s,
@@ -1943,6 +1944,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
    UNUSED(ncols_interleaved);
    UNUSED(blocklen);

+#if !((defined(_MSC_VER)) && !defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
    const uint8x16_t m4b = vdupq_n_u8(0x0f);

    // 8 accumulators: 2 row pairs × 4 col pairs
@@ -1960,17 +1962,21 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,

            for (int b = 0; b < nb; b++) {
                // bsums pairs belongs to the same q8_k subblock
-                const int16x8_t y_bsums[4]{
+                const int16x8_t bsums[4]{
                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
                };
+                int16_t bsums_arr[4][8];
+                for (int q8_row = 0; q8_row < 4; q8_row++) {
+                    vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
+                }

                int32x4_t sb_acc[4];    // Aux accumulators to store subblock (partial) results
                int32x4_t acc[8];       // rows 01 stored in [0][1][2][3] rows 23 stored in [4][5][6][7]
                int32x4_t bias_acc[8];  // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123 ...
-                for (int i = 0; i < 8; ++i) {
+                for (int i = 0; i < 8; i++) {
                    acc[i] = vdupq_n_s32(0);
                    bias_acc[i] = vdupq_n_s32(0);
                }
@@ -1992,7 +1998,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
19921998 int8x16_t q8_qs_23[8 ];
19931999
19942000 // Load 32-byte per row pair, 1 subblock each time
1995- for (int i = 0 ; i < 8 ; ++i ) {
2001+ for (int i = 0 ; i < 8 ; i++ ) {
19962002 const int offset = i * 32 ; // 16 for row 01, 16 for row 23
19972003 q8_qs_01[i] = vld1q_s8 (q8_base + offset);
19982004 q8_qs_23[i] = vld1q_s8 (q8_base + offset + 16 );
@@ -2007,7 +2013,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,

                    // Q4s columns iterated in pairs (01, 23, 45, 67)
                    for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
-                        for (int i = 0; i < 4; ++i) {
+                        for (int i = 0; i < 4; i++) {
                            sb_acc[i] = vdupq_n_s32(0);
                        }

@@ -2063,16 +2069,16 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
                    for (int q8_row = 0; q8_row < 4; q8_row++) {
                        // Each pair of subblocks share the same bsums
                        // Load scalar bsum → broadcast to a vector (vdupq_n_s16(s)).
-                        int16x8_t bsums_vec_lo = vdupq_n_s16(y_bsums[sb][q8_row * 2]);
-                        int16x8_t bsums_vec_hi = vdupq_n_s16(y_bsums[sb][q8_row * 2 + 1]);
+                        int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][q8_row * 2]);
+                        int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1]);

                        bias_acc[2 * q8_row] =
-                            vmlal_s16(bias_acc[2 * q8_row], vget_low_s16(bsums_vec_lo), vget_low_s16(q4sb_mins[0]));
+                            vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
                        bias_acc[2 * q8_row] =
-                            vmlal_s16(bias_acc[2 * q8_row], vget_low_s16(bsums_vec_hi), vget_low_s16(q4sb_mins[1]));
-                        bias_acc[2 * q8_row + 1] = vmlal_s16(bias_acc[2 * q8_row + 1], vget_high_s16(bsums_vec_lo),
+                            vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
+                        bias_acc[2 * q8_row + 1] = vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo,
                                                             vget_high_s16(q4sb_mins[0]));
-                        bias_acc[2 * q8_row + 1] = vmlal_s16(bias_acc[2 * q8_row + 1], vget_high_s16(bsums_vec_hi),
+                        bias_acc[2 * q8_row + 1] = vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi,
                                                             vget_high_s16(q4sb_mins[1]));
                    }
                } // for sb
@@ -2095,19 +2101,13 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,

                for (int i = 0; i < q8_k_blocklen; i++) {
                    for (int j = 0; j < 2; j++) {
-                        const float32x4_t dmins = {
-                            q8_ptr[b].d[i] * GGML_CPU_FP16_TO_FP32(q4_ptr[b].dmin[j * 4 + 0]),
-                            q8_ptr[b].d[i] * GGML_CPU_FP16_TO_FP32(q4_ptr[b].dmin[j * 4 + 1]),
-                            q8_ptr[b].d[i] * GGML_CPU_FP16_TO_FP32(q4_ptr[b].dmin[j * 4 + 2]),
-                            q8_ptr[b].d[i] * GGML_CPU_FP16_TO_FP32(q4_ptr[b].dmin[j * 4 + 3]),
-                        };
+                        // TODO: Change to a single vmul
+                        float32x4_t q8_d    = vdupq_n_f32(q8_ptr[b].d[i]);
+                        float32x4_t q4_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *)(q4_ptr[b].dmin + j * 4)));
+                        const float32x4_t dmins = vmulq_f32(q4_dmin, q8_d);

-                        const float32x4_t scale = {
-                            q8_ptr[b].d[i] * GGML_CPU_FP16_TO_FP32(q4_ptr[b].d[j * 4 + 0]),
-                            q8_ptr[b].d[i] * GGML_CPU_FP16_TO_FP32(q4_ptr[b].d[j * 4 + 1]),
-                            q8_ptr[b].d[i] * GGML_CPU_FP16_TO_FP32(q4_ptr[b].d[j * 4 + 2]),
-                            q8_ptr[b].d[i] * GGML_CPU_FP16_TO_FP32(q4_ptr[b].d[j * 4 + 3]),
-                        };
+                        float32x4_t q4_d = vcvt_f32_f16(vld1_f16((const __fp16 *)(q4_ptr[b].d + j * 4)));
+                        const float32x4_t scale = vmulq_f32(q4_d, q8_d);

                        acc_f32[2 * i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins);
                        acc_f32[2 * i + j] =
@@ -2127,5 +2127,8 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
            }
        } // for x
    } // for y
+    return;
+#endif
+    ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}

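Note on the packed scale/min layout that decode_q4_Kx8_scales_mins unpacks above: a Q4_K super-block of 256 weights carries eight 6-bit sub-block scales and eight 6-bit mins in 12 bytes. Bytes 0-3 hold the low 6 bits of scales 0-3, bytes 4-7 the low 6 bits of mins 0-3, and bytes 8-11 hold the low nibbles of scales/mins 4-7, with their top 2 bits kept in the upper bits of bytes 0-3 (scales) and 4-7 (mins). The decoded mins feed bias_acc, which is subtracted via vmlsq_f32 with dmins, while the scales weight the integer dot products. The scalar sketch below is only an illustrative reference for that bit layout, assuming the standard Q4_K packing; the helper name decode_scale_min_scalar is invented for this example and is not part of the patch.

#include <stdint.h>

// Illustrative scalar reference (not part of the patch): unpack the 6-bit
// scale and min of sub-block j (0..7) from the 12-byte packed field of a
// Q4_K block, matching the bit layout described above.
static inline void decode_scale_min_scalar(int j, const uint8_t * q,
                                           uint8_t * sc, uint8_t * mn) {
    if (j < 4) {
        *sc = q[j]     & 63;   // low 6 bits of bytes 0-3
        *mn = q[j + 4] & 63;   // low 6 bits of bytes 4-7
    } else {
        *sc = (q[j + 4] & 0x0F) | ((q[j - 4] >> 6) << 4);  // nibble from bytes 8-11, top 2 bits from bytes 0-3
        *mn = (q[j + 4] >>  4)  | ((q[j    ] >> 6) << 4);  // nibble from bytes 8-11, top 2 bits from bytes 4-7
    }
}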