
Commit 0b1fec6

cpu: arm: Implemented REPACK gemv for Q4_K
Signed-off-by: Alberto Cabrera <[email protected]>
1 parent 28e30c2 commit 0b1fec6

2 files changed: 159 additions & 1 deletion

ggml/src/ggml-cpu/arch-fallback.h

Lines changed: 0 additions & 1 deletion
@@ -51,7 +51,6 @@
 #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64)
 // repack.cpp
 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
-#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
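Note: with this alias removed, the aarch64 build no longer maps the Q4_K gemv onto the generic kernel; the dedicated implementation added below in arch/arm/repack.cpp now provides ggml_gemv_q4_K_8x8_q8_K itself and only calls the _generic variant at runtime when the required NEON features are unavailable. A minimal sketch of the alias pattern being dropped (an assumption about how arch-fallback.h is consumed, not code from this commit):

    // While the define is active, the preprocessor renames the _generic symbol,
    // so the generic kernel is compiled under the arch-facing name:
    //   #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
    // Once an architecture ships its own ggml_gemv_q4_K_8x8_q8_K, the define has to
    // go so that the specialised kernel and its _generic fallback stay two distinct
    // functions.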

ggml/src/ggml-cpu/arch/arm/repack.cpp

Lines changed: 159 additions & 0 deletions
@@ -498,6 +498,165 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
     ggml_gemv_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
 
+void ggml_gemv_q4_K_8x8_q8_K(int n,
+                             float * GGML_RESTRICT s,
+                             size_t bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int nr,
+                             int nc) {
+    constexpr int qk = QK_K;
+    const int nb = n / qk;
+
+    constexpr int ncols_interleaved = 8;
+    constexpr int blocklen = 8;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    constexpr int col_pairs = ncols_interleaved / 2;
+    const uint8x16_t m4b = vdupq_n_u8(0x0f);
+
+    // 1x8 tile = 2 x 4
+    float32x4_t acc_f32[ncols_interleaved / 4];
+
+    const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
+
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
+
+        for (int i = 0; i < ncols_interleaved / 4; i++) {
+            acc_f32[i] = vdupq_n_f32(0);
+        }
+
+        for (int b = 0; b < nb; b++) {
+            float32x4_t q4_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d));      // d0 d1 d2 d3
+            float32x4_t q4_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4));  // d4 d5 d6 d7
+            float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
+            float32x4_t sb_scale_0 = vmulq_f32(q4_d_0, q8_d);
+            float32x4_t sb_scale_1 = vmulq_f32(q4_d_1, q8_d);
+            float32x4_t q4_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin));      // dmin 0..3
+            float32x4_t q4_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4));  // dmin 4..7
+            float32x4_t sb_min_0 = vmulq_f32(q4_dmin_0, q8_d);
+            float32x4_t sb_min_1 = vmulq_f32(q4_dmin_1, q8_d);
+
+            // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567
+            int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
+            // 2 sb each iteration
+            int32x4_t acc_lo[col_pairs];
+            int32x4_t acc_hi[col_pairs];
+
+            // Each bsum covers 16 elements; the pairwise add leaves us with the 8 bsums of the entire block
+            const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
+            int16_t bsums_arr[8];
+            vst1q_s16(bsums_arr, bsums);
+
+            for (int sb = 0; sb < QK_K / 64; sb++) {
+                for (int i = 0; i < col_pairs; i++) {
+                    acc_lo[i] = vdupq_n_s32(0);
+                    acc_hi[i] = vdupq_n_s32(0);
+                }
+                // Need scales for the low and high nibbles
+                // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
+                int16x8_t q4sb_mins[2]; // int16 as it's needed for bias_acc later
+                int16x8_t q4sb_scales[2];
+                for (int i = 0; i < 2; i++) {
+                    int8_t aux_q4sb[8];
+                    const int offset = sb * 24 + i * 12;
+                    decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
+                    q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
+                }
+
+                const uint8_t * q4_base = q4_ptr[b].qs + sb * QK_K;
+
+                // Load the 64 quants from q8_K duplicated to use vecdots with the interleaved columns,
+                // but still need the qs to match the low and high nibbles from q4
+                const int8_t * q8_base = q8_ptr[b].qs + sb * 64;
+                int8x16_t q8_qs[8];
+                for (int i = 0; i < 8; i++) {
+                    q8_qs[i] = (int8x16_t) vld1q_dup_s64((const int64_t *)(q8_base + i * 8));
+                }
+
+                // Q4s columns iterated in pairs (01, 23, 45, 67)
+                for (int cp = 0; cp < col_pairs; cp++) {
+                    uint8x16_t q4_qs_cp_0 = vld1q_u8(q4_base + 16 * cp);
+                    uint8x16_t q4_qs_cp_1 = vld1q_u8(q4_base + 16 * cp + 64);
+                    uint8x16_t q4_qs_cp_2 = vld1q_u8(q4_base + 16 * cp + 128);
+                    uint8x16_t q4_qs_cp_3 = vld1q_u8(q4_base + 16 * cp + 192);
+
+                    acc_lo[cp] = vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_0, m4b)), q8_qs[0]); //  0.. 7
+                    acc_lo[cp] = vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_1, m4b)), q8_qs[1]); //  8..15
+                    acc_lo[cp] = vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_2, m4b)), q8_qs[2]); // 16..23
+                    acc_lo[cp] = vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_3, m4b)), q8_qs[3]); // 24..31
+
+                    acc_hi[cp] = vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_0, 4)), q8_qs[4]); // 32..39
+                    acc_hi[cp] = vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_1, 4)), q8_qs[5]); // 40..47
+                    acc_hi[cp] = vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_2, 4)), q8_qs[6]); // 48..55
+                    acc_hi[cp] = vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_3, 4)), q8_qs[7]); // 56..63
+                }
+
+                // Iterates over a pair of column pairs (4 columns) to use a single 128-bit register
+                // p = 0 -> cols 0123, p = 2 -> cols 4567
+                for (int i = 0, p = 0; p < col_pairs; i++, p += 2) {
+                    int16x4_t group_scales_lo = p == 0 ? vget_low_s16(q4sb_scales[0]) : vget_high_s16(q4sb_scales[0]);
+                    int16x4_t group_scales_hi = p == 0 ? vget_low_s16(q4sb_scales[1]) : vget_high_s16(q4sb_scales[1]);
+                    float32x4_t sb_scale = p == 0 ? sb_scale_0 : sb_scale_1;
+
+                    // 0123 or 4567
+                    // TODO: Single superblock mul at the end of the superblock
+                    float32x4_t sumf_0 = vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_lo), vpaddq_s32(acc_lo[p], acc_lo[p + 1])));
+                    acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_0);
+
+                    float32x4_t sumf_1 = vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_hi), vpaddq_s32(acc_hi[p], acc_hi[p + 1])));
+                    acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_1);
+                }
+
+                // Accumulate bsums * mins into bias_acc
+                // Each pair of subblocks shares the same bsum
+                // Load scalar bsum and broadcast it to a vector (vdup_n_s16)
+                int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
+                int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
+
+                // cols 0-3 bias
+                bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
+                bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
+
+                // cols 4-7 bias
+                bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
+                bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
+            } // for sb
+
+            acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0);
+            acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_1);
+        } // for b
+
+        int base = x * ncols_interleaved;
+        vst1q_f32(s + base, acc_f32[0]);
+        vst1q_f32(s + base + 4, acc_f32[1]);
+    } // for x
+    return;
+#endif
+    ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+}
+
 void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
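For reference, here is a minimal, self-contained scalar sketch of the value each of the 8 outputs of the NEON kernel above should equal: the standard Q4_K x Q8_K dot product, which per super-block is d * sum_j(sc_j * dot_j) minus dmin * sum_j(m_j * bsum_j), evaluated on a single non-interleaved column. The struct layouts and the 6-bit scale/min unpacking mirror ggml's generic Q4_K code as I understand it; treat the field names and the helper as illustrative assumptions, not code from this commit.

    #include <stdint.h>

    #define QK_K 256

    typedef struct {
        float   d, dmin;          // super-block scales (already converted from fp16)
        uint8_t scales[12];       // 8 x (6-bit scale, 6-bit min), packed
        uint8_t qs[QK_K / 2];     // 4-bit quants, two per byte
    } q4_K_ref;

    typedef struct {
        float   d;                // per-block scale of the activations
        int8_t  qs[QK_K];         // 8-bit quants
        int16_t bsums[QK_K / 16]; // sums of groups of 16 quants
    } q8_K_ref;

    // Unpack the 6-bit scale/min of sub-block j (0..7) from the 12-byte field.
    static void get_scale_min_ref(int j, const uint8_t * q, uint8_t * sc, uint8_t * m) {
        if (j < 4) {
            *sc = q[j] & 63;
            *m  = q[j + 4] & 63;
        } else {
            *sc = (q[j + 4] & 0x0F) | ((q[j - 4] >> 6) << 4);
            *m  = (q[j + 4] >>  4) | ((q[j    ] >> 6) << 4);
        }
    }

    // dot(x_col, y) over nb super-blocks: d * sum_j(sc_j * dot_j) - dmin * sum_j(m_j * bsum_j)
    static float gemv_q4_K_column_ref(const q4_K_ref * x, const q8_K_ref * y, int nb) {
        float sumf = 0.0f;
        for (int b = 0; b < nb; b++) {
            int32_t sumi = 0, sumi_min = 0;
            for (int j = 0; j < QK_K / 32; j++) {             // 8 sub-blocks of 32 values
                uint8_t sc, m;
                get_scale_min_ref(j, x[b].scales, &sc, &m);
                int32_t dot = 0, bsum = 0;
                for (int l = 0; l < 32; l++) {
                    const uint8_t byte = x[b].qs[(j / 2) * 32 + l];
                    const int q4 = (j & 1) ? (byte >> 4) : (byte & 0x0F); // hi/lo nibble
                    const int q8 = y[b].qs[j * 32 + l];
                    dot  += q4 * q8;
                    bsum += q8;                               // the kernel reads this from bsums instead
                }
                sumi     += sc * dot;
                sumi_min += m  * bsum;
            }
            sumf += y[b].d * (x[b].d * sumi - x[b].dmin * sumi_min);
        }
        return sumf;
    }

The committed kernel produces the same result for 8 columns at once: vdotq_s32 accumulates the nibble-by-int8 dot products over the interleaved block_q4_Kx8 layout, and the bias term comes from block_q8_K's precomputed bsums rather than re-summing the activations.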
