@@ -525,9 +525,9 @@ void ggml_gemv_q4_K_8x8_q8_K(int n,
     UNUSED (ncols_interleaved);
     UNUSED (blocklen);

-#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
-    constexpr int col_pairs = ncols_interleaved / 2 ;
-    const uint8x16_t m4b = vdupq_n_u8 (0x0f );
+#if !((defined(_MSC_VER)) && !defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
+    constexpr int col_pairs = ncols_interleaved / 2 ;
+    const uint8x16_t m4b = vdupq_n_u8 (0x0f );

     // 1x8 tile = 2 x 4
     float32x4_t acc_f32[ncols_interleaved / 4 ];
@@ -542,25 +542,25 @@ void ggml_gemv_q4_K_8x8_q8_K(int n,
         }

         for (int b = 0 ; b < nb; b++) {
-            float32x4_t q4_d_0 = vcvt_f32_f16 (vld1_f16 ((const __fp16 *)q4_ptr[b].d )); // d0 d1 d2 d3
-            float32x4_t q4_d_1 = vcvt_f32_f16 (vld1_f16 ((const __fp16 *)q4_ptr[b].d + 4 )); // d4 d5 d6 d7
-            float32x4_t q8_d = vdupq_n_f32 (q8_ptr[b].d );
-            float32x4_t sb_scale_0 = vmulq_f32 (q4_d_0, q8_d);
-            float32x4_t sb_scale_1 = vmulq_f32 (q4_d_1, q8_d);
-            float32x4_t q4_dmin_0 = vcvt_f32_f16 (vld1_f16 ((const __fp16 *) q4_ptr[b].dmin )); // dmin 0..3
-            float32x4_t q4_dmin_1 = vcvt_f32_f16 (vld1_f16 ((const __fp16 *) q4_ptr[b].dmin + 4 )); // dmin 4..7
-            float32x4_t sb_min_0 = vmulq_f32 (q4_dmin_0, q8_d);
-            float32x4_t sb_min_1 = vmulq_f32 (q4_dmin_1, q8_d);
+            float32x4_t q4_d_0 = vcvt_f32_f16 (vld1_f16 ((const __fp16 *) q4_ptr[b].d )); // d0 d1 d2 d3
+            float32x4_t q4_d_1 = vcvt_f32_f16 (vld1_f16 ((const __fp16 *) q4_ptr[b].d + 4 )); // d4 d5 d6 d7
+            float32x4_t q8_d = vdupq_n_f32 (q8_ptr[b].d );
+            float32x4_t sb_scale_0 = vmulq_f32 (q4_d_0, q8_d);
+            float32x4_t sb_scale_1 = vmulq_f32 (q4_d_1, q8_d);
+            float32x4_t q4_dmin_0 = vcvt_f32_f16 (vld1_f16 ((const __fp16 *) q4_ptr[b].dmin )); // dmin 0..3
+            float32x4_t q4_dmin_1 = vcvt_f32_f16 (vld1_f16 ((const __fp16 *) q4_ptr[b].dmin + 4 )); // dmin 4..7
+            float32x4_t sb_min_0 = vmulq_f32 (q4_dmin_0, q8_d);
+            float32x4_t sb_min_1 = vmulq_f32 (q4_dmin_1, q8_d);

             // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567
-            int32x4_t bias_acc[2 ] = {vdupq_n_s32 (0 ), vdupq_n_s32 (0 )};
+            int32x4_t bias_acc[2 ] = { vdupq_n_s32 (0 ), vdupq_n_s32 (0 ) };
             // 2 sb each iteration
             int32x4_t acc_lo[col_pairs];
             int32x4_t acc_hi[col_pairs];

             // Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block
             const int16x8_t bsums = vpaddq_s16 (vld1q_s16 (q8_ptr[b].bsums ), vld1q_s16 (q8_ptr[b].bsums + 8 ));
-            int16_t bsums_arr[8 ];
+            int16_t bsums_arr[8 ];
             vst1q_s16 (bsums_arr, bsums);
             for (int sb = 0 ; sb < QK_K / 64 ; sb++) {
                 for (int i = 0 ; i < col_pairs; i++) {
@@ -578,49 +578,57 @@ void ggml_gemv_q4_K_8x8_q8_K(int n,
                     q4sb_scales[i] = vmovl_s8 (vld1_s8 (aux_q4sb));
                 }

-                const uint8_t *q4_base = q4_ptr[b].qs + sb * QK_K;
+                const uint8_t * q4_base = q4_ptr[b].qs + sb * QK_K;

                 // Load the 64 quants from q8K duplicated to use vecdots with the interelaved columns
                 // but still need the qs to use the low and hi bits from q4
-                const int8_t *q8_base = q8_ptr[b].qs + sb * 64 ;
-                int8x16_t q8_qs[8 ];
+                const int8_t * q8_base = q8_ptr[b].qs + sb * 64 ;
+                int8x16_t q8_qs[8 ];
                 for (int i = 0 ; i < 8 ; i++) {
-                    q8_qs[i] = (int8x16_t ) vld1q_dup_s64 ((const int64_t *)(q8_base + i * 8 ));
+                    q8_qs[i] = (int8x16_t ) vld1q_dup_s64 ((const int64_t *) (q8_base + i * 8 ));
                 }

-
                 // Q4s columns iterated in pairs (01, 23, 45, 67)
                 for (int cp = 0 ; cp < col_pairs; cp++) {
                     uint8x16_t q4_qs_cp_0 = vld1q_u8 (q4_base + 16 * cp);
                     uint8x16_t q4_qs_cp_1 = vld1q_u8 (q4_base + 16 * cp + 64 );
                     uint8x16_t q4_qs_cp_2 = vld1q_u8 (q4_base + 16 * cp + 128 );
                     uint8x16_t q4_qs_cp_3 = vld1q_u8 (q4_base + 16 * cp + 192 );

-                    acc_lo[cp] = ggml_vdotq_s32 (acc_lo[cp], vreinterpretq_s8_u8 (vandq_u8 (q4_qs_cp_0, m4b)), q8_qs[0 ]); // 0 .. 7
-                    acc_lo[cp] = ggml_vdotq_s32 (acc_lo[cp], vreinterpretq_s8_u8 (vandq_u8 (q4_qs_cp_1, m4b)), q8_qs[1 ]); // 8 ..15
-                    acc_lo[cp] = ggml_vdotq_s32 (acc_lo[cp], vreinterpretq_s8_u8 (vandq_u8 (q4_qs_cp_2, m4b)), q8_qs[2 ]); // 16..23
-                    acc_lo[cp] = ggml_vdotq_s32 (acc_lo[cp], vreinterpretq_s8_u8 (vandq_u8 (q4_qs_cp_3, m4b)), q8_qs[3 ]); // 24..31
-
-                    acc_hi[cp] = ggml_vdotq_s32 (acc_hi[cp], vreinterpretq_s8_u8 (vshrq_n_u8 (q4_qs_cp_0, 4 )), q8_qs[4 ]); // 32..39
-                    acc_hi[cp] = ggml_vdotq_s32 (acc_hi[cp], vreinterpretq_s8_u8 (vshrq_n_u8 (q4_qs_cp_1, 4 )), q8_qs[5 ]); // 40..47
-                    acc_hi[cp] = ggml_vdotq_s32 (acc_hi[cp], vreinterpretq_s8_u8 (vshrq_n_u8 (q4_qs_cp_2, 4 )), q8_qs[6 ]); // 48..55
-                    acc_hi[cp] = ggml_vdotq_s32 (acc_hi[cp], vreinterpretq_s8_u8 (vshrq_n_u8 (q4_qs_cp_3, 4 )), q8_qs[7 ]); // 56..63
+                    acc_lo[cp] =
+                        ggml_vdotq_s32 (acc_lo[cp], vreinterpretq_s8_u8 (vandq_u8 (q4_qs_cp_0, m4b)), q8_qs[0 ]); // 0 .. 7
+                    acc_lo[cp] =
+                        ggml_vdotq_s32 (acc_lo[cp], vreinterpretq_s8_u8 (vandq_u8 (q4_qs_cp_1, m4b)), q8_qs[1 ]); // 8 ..15
+                    acc_lo[cp] =
+                        ggml_vdotq_s32 (acc_lo[cp], vreinterpretq_s8_u8 (vandq_u8 (q4_qs_cp_2, m4b)), q8_qs[2 ]); // 16..23
+                    acc_lo[cp] =
+                        ggml_vdotq_s32 (acc_lo[cp], vreinterpretq_s8_u8 (vandq_u8 (q4_qs_cp_3, m4b)), q8_qs[3 ]); // 24..31
+
+                    acc_hi[cp] =
+                        ggml_vdotq_s32 (acc_hi[cp], vreinterpretq_s8_u8 (vshrq_n_u8 (q4_qs_cp_0, 4 )), q8_qs[4 ]); // 32..39
+                    acc_hi[cp] =
+                        ggml_vdotq_s32 (acc_hi[cp], vreinterpretq_s8_u8 (vshrq_n_u8 (q4_qs_cp_1, 4 )), q8_qs[5 ]); // 40..47
+                    acc_hi[cp] =
+                        ggml_vdotq_s32 (acc_hi[cp], vreinterpretq_s8_u8 (vshrq_n_u8 (q4_qs_cp_2, 4 )), q8_qs[6 ]); // 48..55
+                    acc_hi[cp] =
+                        ggml_vdotq_s32 (acc_hi[cp], vreinterpretq_s8_u8 (vshrq_n_u8 (q4_qs_cp_3, 4 )), q8_qs[7 ]); // 56..63
                 }

-
                 // Iterates over a pair of column pairs (4 columns) to use a single 128 register
                 // p = 0 -> 0123 p2 -> 4567
                 for (int i = 0 , p = 0 ; p < col_pairs; i++, p += 2 ) {
-                    int16x4_t group_scales_lo = p == 0 ? vget_low_s16 (q4sb_scales[0 ]) : vget_high_s16 (q4sb_scales[0 ]);
-                    int16x4_t group_scales_hi = p == 0 ? vget_low_s16 (q4sb_scales[1 ]) : vget_high_s16 (q4sb_scales[1 ]);
-                    float32x4_t sb_scale = p == 0 ? sb_scale_0 : sb_scale_1;
+                    int16x4_t group_scales_lo = p == 0 ? vget_low_s16 (q4sb_scales[0 ]) : vget_high_s16 (q4sb_scales[0 ]);
+                    int16x4_t group_scales_hi = p == 0 ? vget_low_s16 (q4sb_scales[1 ]) : vget_high_s16 (q4sb_scales[1 ]);
+                    float32x4_t sb_scale = p == 0 ? sb_scale_0 : sb_scale_1;

                     // 0123 or 4567
                     // TODO: Single superblock mul at the end of the superblock
-                    float32x4_t sumf_0 = vcvtq_f32_s32 (vmulq_s32 (vmovl_s16 (group_scales_lo), vpaddq_s32 (acc_lo[p], acc_lo[p + 1 ])));
+                    float32x4_t sumf_0 =
+                        vcvtq_f32_s32 (vmulq_s32 (vmovl_s16 (group_scales_lo), vpaddq_s32 (acc_lo[p], acc_lo[p + 1 ])));
                     acc_f32[i] = vfmaq_f32 (acc_f32[i], sb_scale, sumf_0);

-                    float32x4_t sumf_1 = vcvtq_f32_s32 (vmulq_s32 (vmovl_s16 (group_scales_hi), vpaddq_s32 (acc_hi[p], acc_hi[p + 1 ])));
+                    float32x4_t sumf_1 =
+                        vcvtq_f32_s32 (vmulq_s32 (vmovl_s16 (group_scales_hi), vpaddq_s32 (acc_hi[p], acc_hi[p + 1 ])));
                     acc_f32[i] = vfmaq_f32 (acc_f32[i], sb_scale, sumf_1);
                 }

@@ -631,16 +639,12 @@ void ggml_gemv_q4_K_8x8_q8_K(int n,
                 int16x4_t bsums_vec_hi = vdup_n_s16 (bsums_arr[2 * sb + 1 ]);

                 // cols 0-3 bias
-                bias_acc[0 ] =
-                    vmlal_s16 (bias_acc[0 ], bsums_vec_lo, vget_low_s16 (q4sb_mins[0 ]));
-                bias_acc[0 ] =
-                    vmlal_s16 (bias_acc[0 ], bsums_vec_hi, vget_low_s16 (q4sb_mins[1 ]));
+                bias_acc[0 ] = vmlal_s16 (bias_acc[0 ], bsums_vec_lo, vget_low_s16 (q4sb_mins[0 ]));
+                bias_acc[0 ] = vmlal_s16 (bias_acc[0 ], bsums_vec_hi, vget_low_s16 (q4sb_mins[1 ]));

                 // cols 4-7 bias
-                bias_acc[1 ] = vmlal_s16 (bias_acc[1 ], bsums_vec_lo,
-                    vget_high_s16 (q4sb_mins[0 ]));
-                bias_acc[1 ] = vmlal_s16 (bias_acc[1 ], bsums_vec_hi,
-                    vget_high_s16 (q4sb_mins[1 ]));
+                bias_acc[1 ] = vmlal_s16 (bias_acc[1 ], bsums_vec_lo, vget_high_s16 (q4sb_mins[0 ]));
+                bias_acc[1 ] = vmlal_s16 (bias_acc[1 ], bsums_vec_hi, vget_high_s16 (q4sb_mins[1 ]));
             } // for sb

             acc_f32[0 ] = vmlsq_f32 (acc_f32[0 ], vcvtq_f32_s32 (bias_acc[0 ]), sb_min_0);
@@ -652,7 +656,7 @@ void ggml_gemv_q4_K_8x8_q8_K(int n,
         vst1q_f32 (s + base + 4 , acc_f32[1 ]);
     } // for x
     return ;
-#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
     ggml_gemv_q4_K_8x8_q8_K_generic (n, s, bs, vx, vy, nr, nc);
 }

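The hunks above only re-wrap and re-align the NEON path of ggml_gemv_q4_K_8x8_q8_K; the arithmetic is unchanged. As a reading aid, below is a minimal scalar sketch of the per-superblock math that the vectorized loop implements. It assumes the 6-bit sub-block scales and mins have already been unpacked to bytes (the kernel does this into aux_q4sb) and ignores the 8-column interleaving; the function name and signature are illustrative, not part of the source.

```cpp
#include <stdint.h>

// Scalar reference for one Q4_K superblock (256 weights) against one Q8_K row
// segment. Low nibbles of each 32-byte chunk are sub-block 2*j, high nibbles
// are sub-block 2*j+1, mirroring the acc_lo/acc_hi split in the kernel.
static float q4K_q8K_superblock_dot(float         d4,          // Q4_K super-scale (fp16 -> fp32)
                                    float         dmin4,       // Q4_K super-min  (fp16 -> fp32)
                                    const uint8_t scales[8],   // unpacked 6-bit sub-block scales
                                    const uint8_t mins[8],     // unpacked 6-bit sub-block mins
                                    const uint8_t qs[128],     // 256 packed 4-bit quants
                                    float         d8,          // Q8_K scale
                                    const int8_t  q8[256],     // Q8_K quants
                                    const int16_t bsums[8]) {  // per-32 sums of q8
    int32_t acc  = 0;  // scale-weighted quant products
    int32_t bias = 0;  // min * bsum corrections
    for (int sb = 0; sb < 8; sb += 2) {
        const uint8_t * q4 = qs + (sb / 2) * 32;
        int32_t lo = 0;
        int32_t hi = 0;
        for (int i = 0; i < 32; i++) {
            lo += (q4[i] & 0x0f) * q8[sb * 32 + i];        // sub-block sb
            hi += (q4[i] >> 4)   * q8[(sb + 1) * 32 + i];  // sub-block sb + 1
        }
        acc  += scales[sb] * lo + scales[sb + 1] * hi;
        bias += mins[sb] * bsums[sb] + mins[sb + 1] * bsums[sb + 1];
    }
    // Each weight dequantizes as w = d4*scale*q - dmin4*min, so the dot product
    // against q8 (scaled by d8) splits into the two terms accumulated above.
    return d4 * d8 * (float) acc - dmin4 * d8 * (float) bias;
}
```

The NEON code keeps exactly these two running sums per output column (acc_lo/acc_hi folded into acc_f32, and bias_acc), just for eight interleaved columns at a time.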
@@ -2101,9 +2105,10 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
     UNUSED (ncols_interleaved);
     UNUSED (blocklen);

-#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
-    constexpr int q8_k_blocklen = 4 ;
-    const uint8x16_t m4b = vdupq_n_u8 (0x0f );
+#if !((defined(_MSC_VER)) && !defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && \
+    defined (__ARM_FEATURE_MATMUL_INT8)
+    constexpr int q8_k_blocklen = 4 ;
+    const uint8x16_t m4b = vdupq_n_u8 (0x0f );

     // 8 accumulators: 2 row pairs × 4 col pairs
     float32x4_t acc_f32[blocklen];
@@ -2131,9 +2136,9 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
                     vst1q_s16 (bsums_arr[q8_row], bsums[q8_row]);
                 }

-                int32x4_t sb_acc[4 ]; // Aux accumulators to store subblock (partial) results
-                int32x4_t acc[8 ]; // rows 01 stored in [0][1][2][3] rows 23 stored in [4][5][6][7]
-                int32x4_t bias_acc[8 ]; // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123 ...
+                int32x4_t sb_acc[4 ]; // Aux accumulators to store subblock (partial) results
+                int32x4_t acc[8 ]; // rows 01 stored in [0][1][2][3] rows 23 stored in [4][5][6][7]
+                int32x4_t bias_acc[8 ]; // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123 ...
                 for (int i = 0 ; i < 8 ; i++) {
                     acc[i] = vdupq_n_s32 (0 );
                     bias_acc[i] = vdupq_n_s32 (0 );
@@ -2150,16 +2155,16 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
                     }

                     // q8_ptr[b].qs has interleaved Q8 rows (01, 23)
-                    const int8_t *q8_base = q8_ptr[b].qs + sb * 256 ;
+                    const int8_t * q8_base = q8_ptr[b].qs + sb * 256 ;

                     int8x16_t q8_qs_01[8 ];
                     int8x16_t q8_qs_23[8 ];

                     // Load 32-byte per row pair, 1 subblock each time
                     for (int i = 0 ; i < 8 ; i++) {
                         const int offset = i * 32 ; // 16 for row 01, 16 for row 23
-                        q8_qs_01[i] = vld1q_s8 (q8_base + offset);
-                        q8_qs_23[i] = vld1q_s8 (q8_base + offset + 16 );
+                        q8_qs_01[i] = vld1q_s8 (q8_base + offset);
+                        q8_qs_23[i] = vld1q_s8 (q8_base + offset + 16 );
                     }

                     const int8x16_t q8s[2 ][8 ] = {
@@ -2234,10 +2239,10 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
                             vmlal_s16 (bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16 (q4sb_mins[0 ]));
                         bias_acc[2 * q8_row] =
                             vmlal_s16 (bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16 (q4sb_mins[1 ]));
-                        bias_acc[2 * q8_row + 1 ] = vmlal_s16 (bias_acc[ 2 * q8_row + 1 ], bsums_vec_lo,
-                            vget_high_s16 (q4sb_mins[0 ]));
-                        bias_acc[2 * q8_row + 1 ] = vmlal_s16 (bias_acc[ 2 * q8_row + 1 ], bsums_vec_hi,
-                            vget_high_s16 (q4sb_mins[1 ]));
+                        bias_acc[2 * q8_row + 1 ] =
+                            vmlal_s16 (bias_acc[ 2 * q8_row + 1 ], bsums_vec_lo, vget_high_s16 (q4sb_mins[0 ]));
+                        bias_acc[2 * q8_row + 1 ] =
+                            vmlal_s16 (bias_acc[ 2 * q8_row + 1 ], bsums_vec_hi, vget_high_s16 (q4sb_mins[1 ]));
                     }
                 } // for sb

@@ -2259,11 +2264,11 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,

                 for (int i = 0 ; i < q8_k_blocklen; i++) {
                     for (int j = 0 ; j < 2 ; j++) {
-                        float32x4_t q8_d = vdupq_n_f32 (q8_ptr[b].d [i]);
-                        float32x4_t q4_dmin = vcvt_f32_f16 (vld1_f16 ((const __fp16 *)(q4_ptr[b].dmin + j * 4 )));
-                        const float32x4_t dmins = vmulq_f32 (q4_dmin, q8_d);
+                        float32x4_t q8_d = vdupq_n_f32 (q8_ptr[b].d [i]);
+                        float32x4_t q4_dmin = vcvt_f32_f16 (vld1_f16 ((const __fp16 *) (q4_ptr[b].dmin + j * 4 )));
+                        const float32x4_t dmins = vmulq_f32 (q4_dmin, q8_d);

-                        float32x4_t q4_d = vcvt_f32_f16 (vld1_f16 ((const __fp16 *)(q4_ptr[b].d + j * 4 )));
+                        float32x4_t q4_d = vcvt_f32_f16 (vld1_f16 ((const __fp16 *) (q4_ptr[b].d + j * 4 )));
                         const float32x4_t scale = vmulq_f32 (q4_d, q8_d);

                         acc_f32[2 * i + j] = vmlsq_f32 (acc_f32[2 * i + j], vcvtq_f32_s32 (bias_acc[2 * i + j]), dmins);
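The GEMM hunks are likewise formatting-only. One detail both kernels rely on, visible in the bsums/bias_acc lines, is that each Q4_K sub-block dequantizes as w = d*scale*q - dmin*min, so against a Q8_K row the min term collapses to dmin*min*sum(q8) and only per-sub-block sums of q8 are needed. Q8_K stores 16 such partial sums (one per 16 quants); a single pairwise add folds them into the 8 per-32 sums that line up with Q4_K's 8 sub-blocks, which is what the vpaddq_s16 (vld1q_s16 (q8_ptr[b].bsums ), vld1q_s16 (q8_ptr[b].bsums + 8 )) line in the GEMV hunk computes. A scalar sketch of that fold, with a helper name invented purely for illustration:

```cpp
#include <stdint.h>

// Equivalent of the vpaddq_s16 fold used above: bsums8[i] = bsums16[2i] + bsums16[2i+1].
// The results fit comfortably in int16_t, since each input is a sum of 16 int8_t quants.
static void fold_bsums(const int16_t bsums16[16], int16_t bsums8[8]) {
    for (int i = 0; i < 8; i++) {
        bsums8[i] = (int16_t) (bsums16[2 * i] + bsums16[2 * i + 1]);
    }
}
```

The rest of the GEMM kernel is the same superblock math as the GEMV sketch earlier, widened to a 4x8 tile: acc[8] and bias_acc[8] hold one accumulator per row pair and 4-column group, as the comments in the diff note, and the final loop over i and j computes the per-group scale (q4_d * q8_d) and dmin factors and subtracts the bias contribution with vmlsq_f32.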