@@ -5573,7 +5573,88 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
     uint32_t utmp[4];
 
-#ifdef __ARM_NEON
+#ifdef __ARM_FEATURE_SVE
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
+
+        memcpy(utmp, x[i].scales, K_SCALE_SIZE);
+
+        uint32x2_t mins8 = { 0 };
+        mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0);
+        mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1);
+
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[0] &= kmask1;
+
+        const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
+        const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
+                                         vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
+        sumf -= dmin * vaddvq_s32(prod);
+
+        const uint8_t * scales = (const uint8_t *)utmp;
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const int vector_length = ggml_cpu_get_sve_cnt()*8;
+        const svuint8_t m4b = svdup_n_u8(0xf);
+        const svint32_t mzero = svdup_n_s32(0);
+        svint32_t sumi1 = svdup_n_s32(0);
+        svint32_t sumi1_1 = svdup_n_s32(0);
+        svint32_t sumi1_2 = svdup_n_s32(0);
+        svint32_t sumi2 = svdup_n_s32(0);
+        svint32_t sumi2_1 = svdup_n_s32(0);
+        svint32_t sumi2_2 = svdup_n_s32(0);
+        switch (vector_length) {
+            case 128:
+                {
+                    for (int j = 0; j < QK_K/64; ++j) {
+                        svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), m4b));
+                        svint8_t q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi1_1 = svmla_n_s32_x(svptrue_b32(), sumi1_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
+                        q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), m4b));
+                        q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi1_2 = svmla_n_s32_x(svptrue_b32(), sumi1_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
+
+                        q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), 4));
+                        q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi2_1 = svmla_n_s32_x(svptrue_b32(), sumi2_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
+                        q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), 4));
+                        q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi2_2 = svmla_n_s32_x(svptrue_b32(), sumi2_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
+                        q4 += 32;
+                    }
+                    sumi1 = svadd_s32_x(svptrue_b32(), sumi1_1, sumi1_2);
+                    sumi2 = svadd_s32_x(svptrue_b32(), sumi2_1, sumi2_2);
+                    sumf += d * (svaddv_s32(svptrue_b32(), svadd_s32_x(svptrue_b32(), sumi1, sumi2)));
+                } break;
+            case 256:
+            case 512:
+                {
+                    for (int j = 0; j < QK_K/64; ++j) {
+                        const svuint8_t q4bits  = svld1_u8(svptrue_pat_b8(SV_VL32), q4); q4 += 32;
+                        svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_pat_b8(SV_VL32), q4bits, m4b));
+                        svint8_t q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
+                        sumi1 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
+
+                        q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q4bits, 4));
+                        q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
+                        sumi2 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
+                    }
+                    sumf += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), sumi1, sumi2)));
+                } break;
+            default:
+                assert(false && "Unsupported vector length");
+                break;
+        }
+    }
+    *s = sumf;
+#elif __ARM_NEON
     const uint8x16_t m4b = vdupq_n_u8(0xf);
     const int32x4_t mzero = vdupq_n_s32(0);
 
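For readers less familiar with the q4_K layout, here is a minimal scalar sketch (not part of the commit) of what one iteration of the SVE loop above computes for a single super-block pair. dot_one_superblock is a hypothetical helper name, and it assumes scales[0..7] and mins[0..7] have already been unpacked from the packed 6-bit fields, as the kmask/utmp code in the diff does.

#include <stdint.h>

// Scalar reference for one q4_K x q8_K super-block (illustrative only).
// q4 packs 256 4-bit quants two per byte; within each 32-byte chunk the
// low nibbles form one sub-block of 32 values and the high nibbles the next.
static float dot_one_superblock(float d, float dmin,
                                const uint8_t scales[8],   // unpacked 6-bit scales
                                const uint8_t mins[8],     // unpacked 6-bit mins
                                const uint8_t q4[128],
                                const int8_t  q8[256],
                                const int16_t bsums[16]) { // sums of 16 q8 values each
    int32_t sumi = 0;
    for (int j = 0; j < 4; ++j) {           // QK_K/64 = 4 chunks of 64 values
        int32_t lo = 0, hi = 0;
        for (int k = 0; k < 32; ++k) {
            lo += (q4[32*j + k] & 0xf) * q8[64*j + k];      // sumi1 path (& m4b)
            hi += (q4[32*j + k] >>  4) * q8[64*j + 32 + k]; // sumi2 path (>> 4)
        }
        sumi += scales[2*j + 0] * lo + scales[2*j + 1] * hi;
    }
    int32_t summ = 0;                        // the q8sums/mins correction term
    for (int j = 0; j < 8; ++j) {
        summ += mins[j] * (bsums[2*j] + bsums[2*j + 1]);
    }
    return d * sumi - dmin * summ;           // sumf += d*...; sumf -= dmin*...
}

On the dispatch itself: ggml_cpu_get_sve_cnt() reports the SVE register width in bytes, hence the multiply by 8 to get bits. The 256- and 512-bit cases share one body by predicating loads to 32 active byte lanes with svptrue_pat_b8(SV_VL32), while the 128-bit case keeps two independent accumulator chains (sumi1_1/sumi1_2 and sumi2_1/sumi2_2), presumably to expose more instruction-level parallelism across the back-to-back svdot_s32 operations.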