@@ -5573,7 +5573,88 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r

     uint32_t utmp[4];

-#ifdef __ARM_NEON
+#ifdef __ARM_FEATURE_SVE
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
+
+        memcpy(utmp, x[i].scales, K_SCALE_SIZE);
+
+        uint32x2_t mins8 = { 0 };
+        mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0);
+        mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1);
+
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[0] &= kmask1;
+
+        const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
+        const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
+                                         vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
+        sumf -= dmin * vaddvq_s32(prod);
+
+        const uint8_t * scales = (const uint8_t *)utmp;
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const int vector_length = ggml_cpu_get_sve_cnt()*8;
+        const svuint8_t m4b = svdup_n_u8(0xf);
+        const svint32_t mzero = svdup_n_s32(0);
+        svint32_t sumi1 = svdup_n_s32(0);
+        svint32_t sumi1_1 = svdup_n_s32(0);
+        svint32_t sumi1_2 = svdup_n_s32(0);
+        svint32_t sumi2 = svdup_n_s32(0);
+        svint32_t sumi2_1 = svdup_n_s32(0);
+        svint32_t sumi2_2 = svdup_n_s32(0);
+        switch (vector_length) {
+            case 128:
+                {
+                    for (int j = 0; j < QK_K/64; ++j) {
+                        svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), m4b));
+                        svint8_t q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi1_1 = svmla_n_s32_x(svptrue_b32(), sumi1_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
+                        q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), m4b));
+                        q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi1_2 = svmla_n_s32_x(svptrue_b32(), sumi1_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
+
+                        q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), 4));
+                        q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi2_1 = svmla_n_s32_x(svptrue_b32(), sumi2_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
+                        q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), 4));
+                        q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi2_2 = svmla_n_s32_x(svptrue_b32(), sumi2_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
+                        q4 += 32;
+                    }
+                    sumi1 = svadd_s32_x(svptrue_b32(), sumi1_1, sumi1_2);
+                    sumi2 = svadd_s32_x(svptrue_b32(), sumi2_1, sumi2_2);
+                    sumf += d * (svaddv_s32(svptrue_b32(), svadd_s32_x(svptrue_b32(), sumi1, sumi2)));
+                } break;
+            case 256:
+            case 512:
+                {
+                    for (int j = 0; j < QK_K/64; ++j) {
+                        const svuint8_t q4bits = svld1_u8(svptrue_pat_b8(SV_VL32), q4); q4 += 32;
+                        svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_pat_b8(SV_VL32), q4bits, m4b));
+                        svint8_t q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
+                        sumi1 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
+
+                        q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q4bits, 4));
+                        q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
+                        sumi2 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
+                    }
+                    sumf += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), sumi1, sumi2)));
+                } break;
+            default:
+                assert(false && "Unsupported vector length");
+                break;
+        }
+    }
+    *s = sumf;
+#elif __ARM_NEON
     const uint8x16_t m4b = vdupq_n_u8(0xf);
     const int32x4_t mzero = vdupq_n_s32(0);

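For reference, here is a minimal scalar sketch (not part of the patch; the helper names are hypothetical) of what the new SVE path computes per Q4_K block: each 64-quant chunk contributes two 32-element dot products, one over the low nibbles weighted by scales[2*j+0] and one over the high nibbles weighted by scales[2*j+1] (the sumi1/sumi2 updates above), while the per-sub-block minima are folded in once through the precomputed bsums before the main loop.

#include <stdint.h>

// One 64-quant chunk: 32 packed q4 bytes (low nibbles = first 32 values,
// high nibbles = next 32) dotted against 64 q8 values and weighted by the
// chunk's two 6-bit scales. The SVE kernel accumulates the same quantity
// with svdot_s32 + svmla_n_s32_x.
static int32_t q4k_chunk_dot(const uint8_t * q4, const int8_t * q8,
                             uint8_t scale_lo, uint8_t scale_hi) {
    int32_t dot_lo = 0, dot_hi = 0;
    for (int k = 0; k < 32; ++k) {
        dot_lo += (int32_t)(q4[k] & 0xF) * q8[k];        // low nibbles
        dot_hi += (int32_t)(q4[k] >>  4) * q8[32 + k];   // high nibbles
    }
    return dot_lo * scale_lo + dot_hi * scale_hi;
}

// Mins correction subtracted from sumf before the main loop: bsums holds 16
// partial sums of 16 q8 values each; pairing them (as vpaddq_s16 does above)
// gives one sum per 32-value sub-block, weighted by the unpacked 6-bit minima.
static float q4k_mins_correction(const int16_t bsums[16], const uint8_t mins[8],
                                 float dmin) {
    int32_t acc = 0;
    for (int k = 0; k < 8; ++k) {
        acc += (int32_t)(bsums[2*k] + bsums[2*k + 1]) * mins[k];
    }
    return dmin * (float)acc;   // the kernel does: sumf -= this
}

The integer chunk sums are then scaled by d = y[i].d * GGML_FP16_TO_FP32(x[i].d) and the mins correction is subtracted, so only two horizontal reductions are needed per block (vaddvq_s32 for the mins product and svaddv_s32 for the accumulated dot products).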