@@ -498,6 +498,165 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
     ggml_gemv_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
 
+void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    constexpr int qk = QK_K;
+    const int nb = n / qk;
+
+    constexpr int ncols_interleaved = 8;
+    constexpr int blocklen = 8;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
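+    // NEON + dotprod path: for each group of 8 interleaved columns, accumulate the dot product of the
+    // q8_K activation row against the q4_K super-blocks, apply the per-sub-block scales and finally
+    // subtract the per-sub-block min bias.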
+    constexpr int col_pairs = ncols_interleaved / 2;
+    const uint8x16_t m4b = vdupq_n_u8(0x0f);
+
+    // 1x8 tile = 2 x 4
+    float32x4_t acc_f32[ncols_interleaved / 4];
+
+    const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
+
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
+
+        for (int i = 0; i < ncols_interleaved / 4; i++) {
+            acc_f32[i] = vdupq_n_f32(0);
+        }
+
+        for (int b = 0; b < nb; b++) {
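+            // Per-super-block combined scales: d_q4 * d_q8 and dmin_q4 * d_q8, for columns 0-3 and 4-7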
+            float32x4_t q4_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d));     // d0 d1 d2 d3
+            float32x4_t q4_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4)); // d4 d5 d6 d7
+            float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
+            float32x4_t sb_scale_0 = vmulq_f32(q4_d_0, q8_d);
+            float32x4_t sb_scale_1 = vmulq_f32(q4_d_1, q8_d);
+            float32x4_t q4_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin));     // dmin 0..3
+            float32x4_t q4_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4)); // dmin 4..7
+            float32x4_t sb_min_0 = vmulq_f32(q4_dmin_0, q8_d);
+            float32x4_t sb_min_1 = vmulq_f32(q4_dmin_1, q8_d);
+
+            // interleaved bias_acc: [0] -> row 0, cols 0-3; [1] -> row 0, cols 4-7
+            int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
+            // integer accumulators for the 2 sub-blocks handled per iteration
+            int32x4_t acc_lo[col_pairs];
+            int32x4_t acc_hi[col_pairs];
+
+            // Each bsum covers 16 elements; the pairwise add leaves the 8 per-sub-block sums (32 elements each) of the entire block
+            const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
+            int16_t bsums_arr[8];
+            vst1q_s16(bsums_arr, bsums);
+            for (int sb = 0; sb < QK_K / 64; sb++) {
+                for (int i = 0; i < col_pairs; i++) {
+                    acc_lo[i] = vdupq_n_s32(0);
+                    acc_hi[i] = vdupq_n_s32(0);
+                }
+                // Need the scales and mins for the low- and high-nibble sub-blocks:
+                // 12 bytes hold the packed scales and mins of one sub-block for all 8 columns,
+                // 2 sub-blocks per iteration -> 24 bytes, 4 iterations -> 96 bytes total
+                int16x8_t q4sb_mins[2]; // int16 as it is needed for bias_acc later
+                int16x8_t q4sb_scales[2];
+                for (int i = 0; i < 2; i++) {
+                    int8_t aux_q4sb[8];
+                    const int offset = sb * 24 + i * 12;
+                    decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
+                    q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
+                }
+
+                const uint8_t * q4_base = q4_ptr[b].qs + sb * QK_K;
+
+                // Load the 64 q8_K quants of this sub-block pair, each 8-byte group duplicated across the register
+                // so it can be dotted against two interleaved q4 columns at once; the low and high nibbles of q4
+                // still have to be extracted separately below
+                const int8_t * q8_base = q8_ptr[b].qs + sb * 64;
+                int8x16_t q8_qs[8];
+                for (int i = 0; i < 8; i++) {
+                    q8_qs[i] = (int8x16_t) vld1q_dup_s64((const int64_t *)(q8_base + i * 8));
+                }
+
+                // Q4 columns iterated in pairs (01, 23, 45, 67)
+                for (int cp = 0; cp < col_pairs; cp++) {
+                    uint8x16_t q4_qs_cp_0 = vld1q_u8(q4_base + 16 * cp);
+                    uint8x16_t q4_qs_cp_1 = vld1q_u8(q4_base + 16 * cp + 64);
+                    uint8x16_t q4_qs_cp_2 = vld1q_u8(q4_base + 16 * cp + 128);
+                    uint8x16_t q4_qs_cp_3 = vld1q_u8(q4_base + 16 * cp + 192);
+
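+                    // Each vdotq_s32 lane accumulates a 4-byte group: lanes 0-1 belong to column 2*cp,
+                    // lanes 2-3 to column 2*cp + 1 (the q8 values are duplicated across both halves)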
+                    acc_lo[cp] = vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_0, m4b)), q8_qs[0]); //  0.. 7
+                    acc_lo[cp] = vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_1, m4b)), q8_qs[1]); //  8..15
+                    acc_lo[cp] = vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_2, m4b)), q8_qs[2]); // 16..23
+                    acc_lo[cp] = vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_3, m4b)), q8_qs[3]); // 24..31
+
+                    acc_hi[cp] = vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_0, 4)), q8_qs[4]); // 32..39
+                    acc_hi[cp] = vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_1, 4)), q8_qs[5]); // 40..47
+                    acc_hi[cp] = vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_2, 4)), q8_qs[6]); // 48..55
+                    acc_hi[cp] = vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_3, 4)), q8_qs[7]); // 56..63
+                }
+
+                // Iterate over pairs of column pairs (4 columns) so a single 128-bit register holds the results
+                // p = 0 -> cols 0123, p = 2 -> cols 4567
+                for (int i = 0, p = 0; p < col_pairs; i++, p += 2) {
+                    int16x4_t group_scales_lo = p == 0 ? vget_low_s16(q4sb_scales[0]) : vget_high_s16(q4sb_scales[0]);
+                    int16x4_t group_scales_hi = p == 0 ? vget_low_s16(q4sb_scales[1]) : vget_high_s16(q4sb_scales[1]);
+                    float32x4_t sb_scale = p == 0 ? sb_scale_0 : sb_scale_1;
+
+                    // cols 0123 or 4567
+                    // TODO: Single superblock mul at the end of the superblock
+                    float32x4_t sumf_0 = vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_lo), vpaddq_s32(acc_lo[p], acc_lo[p + 1])));
+                    acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_0);
+
+                    float32x4_t sumf_1 = vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_hi), vpaddq_s32(acc_hi[p], acc_hi[p + 1])));
+                    acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_1);
+                }
+
+                // Accumulate the bias (bsum * min); the two sub-blocks of this iteration use
+                // bsums_arr[2*sb] and bsums_arr[2*sb + 1], each broadcast to a vector with vdup_n_s16
+                int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
+                int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
+
+                // first sub-block of the pair (low nibbles), mins for cols 0-3 and 4-7
+                bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
+                bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
+
+                // second sub-block of the pair (high nibbles), mins for cols 0-3 and 4-7
+                bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
+                bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
+            } // for sb
+
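+            // Subtract the accumulated min bias for this super-block, scaled by dmin * d8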
+            acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0);
+            acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_1);
+        } // for b
+
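+        // Store the 8 accumulated results for this group of interleaved columns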
+        int base = x * ncols_interleaved;
+        vst1q_f32(s + base,     acc_f32[0]);
+        vst1q_f32(s + base + 4, acc_f32[1]);
+    } // for x
+    return;
+#endif
+    ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+}
+
 void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;