@@ -187,6 +187,8 @@ static inline __m256i mul_sum_i8_pairs_int32x8(const __m256i x, const __m256i y)
187187}
188188#endif 
189189
// IQ4_NL codebook: maps each 4-bit quant index (0..15) to its non-linear int8 value.
static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};

190192static  void  quantize_q8_0_4x4 (const  float  *  restrict x , void  *  restrict vy , int64_t  k ) {
191193    assert (QK8_0  ==  32 );
192194    assert (k  % QK8_0  ==  0 );
@@ -996,6 +998,102 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
996998    }
997999}
9981000
1001+ void  ggml_gemv_iq4_nl_4x4_q8_0 (int  n , float  *  restrict s , size_t  bs , const  void  *  restrict vx , const  void  *  restrict vy , int  nr , int  nc ) {
1002+     const  int  qk  =  QK8_0 ;
1003+     const  int  nb  =  n  / qk ;
1004+     const  int  ncols_interleaved  =  4 ;
1005+     const  int  blocklen  =  4 ;
1006+ 
1007+     assert  (n  % qk  ==  0 );
1008+     assert  (nc  % ncols_interleaved  ==  0 );
1009+ 
1010+     UNUSED (s );
1011+     UNUSED (bs );
1012+     UNUSED (vx );
1013+     UNUSED (vy );
1014+     UNUSED (nr );
1015+     UNUSED (nc );
1016+     UNUSED (nb );
1017+     UNUSED (ncols_interleaved );
1018+     UNUSED (blocklen );
1019+ 
1020+ #if  ! ((defined(_MSC_VER )) &&  ! defined(__clang__ )) &&  defined(__aarch64__ ) &&  defined(__ARM_NEON )
1021+     if  (ggml_cpu_has_neon ()) {
1022+         const  int8x16_t  kvalues  =  vld1q_s8 (kvalues_iq4nl );
1023+         const  block_q8_0  *  a_ptr  =  (const  block_q8_0  * ) vy ;
1024+         float  *  res_ptr  =  s ;
1025+ 
1026+         for  (int  x  =  0 ; x  <  nc  / ncols_interleaved ; x ++ ) {
1027+             const  block_q4_0x4  *  b_ptr  =  (const  block_q4_0x4  * ) vx  +  (x  *  nb );
1028+ 
1029+             float32x4_t  sumf  =  vdupq_n_f32 (0 );
1030+             for  (int  l  =  0 ; l  <  nb ; l ++ ) {
1031+                 uint8x16_t  b_0  =  vld1q_u8 (b_ptr [l ].qs  +  0 );
1032+                 uint8x16_t  b_1  =  vld1q_u8 (b_ptr [l ].qs  +  16 );
1033+                 uint8x16_t  b_2  =  vld1q_u8 (b_ptr [l ].qs  +  32 );
1034+                 uint8x16_t  b_3  =  vld1q_u8 (b_ptr [l ].qs  +  48 );
1035+ 
1036+                 int8x16_t  b_0_hi  =  vqtbl1q_s8 (kvalues , b_0  >> 4 );
1037+                 int8x16_t  b_0_lo  =  vqtbl1q_s8 (kvalues , b_0  &  0x0F );
1038+                 int8x16_t  b_1_hi  =  vqtbl1q_s8 (kvalues , b_1  >> 4 );
1039+                 int8x16_t  b_1_lo  =  vqtbl1q_s8 (kvalues , b_1  &  0x0F );
1040+                 int8x16_t  b_2_hi  =  vqtbl1q_s8 (kvalues , b_2  >> 4 );
1041+                 int8x16_t  b_2_lo  =  vqtbl1q_s8 (kvalues , b_2  &  0x0F );
1042+                 int8x16_t  b_3_hi  =  vqtbl1q_s8 (kvalues , b_3  >> 4 );
1043+                 int8x16_t  b_3_lo  =  vqtbl1q_s8 (kvalues , b_3  &  0x0F );
1044+ 
1045+                 int8x16_t  a_0  =  vld1q_s8 (a_ptr [l ].qs  +  0 );
1046+                 int8x16_t  a_1  =  vld1q_s8 (a_ptr [l ].qs  +  16 );
1047+ 
1048+                 int32x4_t  sumi  =  vdupq_n_s32 (0 );
1049+                 sumi  =  vdotq_laneq_s32 (sumi , b_0_lo , a_0 , 0 );
1050+                 sumi  =  vdotq_laneq_s32 (sumi , b_0_hi , a_1 , 0 );
1051+                 sumi  =  vdotq_laneq_s32 (sumi , b_1_lo , a_0 , 1 );
1052+                 sumi  =  vdotq_laneq_s32 (sumi , b_1_hi , a_1 , 1 );
1053+                 sumi  =  vdotq_laneq_s32 (sumi , b_2_lo , a_0 , 2 );
1054+                 sumi  =  vdotq_laneq_s32 (sumi , b_2_hi , a_1 , 2 );
1055+                 sumi  =  vdotq_laneq_s32 (sumi , b_3_lo , a_0 , 3 );
1056+                 sumi  =  vdotq_laneq_s32 (sumi , b_3_hi , a_1 , 3 );
1057+ 
1058+                 float32x4_t  a_d  =  vcvt_f32_f16 (vld1_dup_f16 ((const  float16_t  * )& a_ptr [l ].d ));
1059+                 float32x4_t  b_d  =  vcvt_f32_f16 (vld1_f16 ((const  float16_t  * )b_ptr [l ].d ));
1060+                 float32x4_t  d  =  a_d  *  b_d ;
1061+ 
1062+                 sumf  =  vmlaq_f32 (sumf , d , vcvtq_f32_s32 (sumi ));
1063+             }
1064+ 
1065+             vst1q_f32 (res_ptr  +  x  *  4 , sumf );
1066+         }
1067+         return ;
1068+     }
1069+ #endif  // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) 
1070+     {
1071+         float  sumf [4 ];
1072+         int  sumi ;
1073+ 
1074+         const  block_q8_0  *  a_ptr  =  (const  block_q8_0  * ) vy ;
1075+         for  (int  x  =  0 ; x  <  nc  / ncols_interleaved ; x ++ ) {
1076+             const  block_iq4_nlx4  *  b_ptr  =  (const  block_iq4_nlx4  * ) vx  +  (x  *  nb );
1077+ 
1078+             for  (int  j  =  0 ; j  <  ncols_interleaved ; j ++ ) sumf [j ] =  0.0 ;
1079+             for  (int  l  =  0 ; l  <  nb ; l ++ ) {
1080+                 for  (int  k  =  0 ; k  <  (qk  / (2  *  blocklen )); k ++ ) {
1081+                     for  (int  j  =  0 ; j  <  ncols_interleaved ; j ++ ) {
1082+                         sumi  =  0 ;
1083+                         for  (int  i  =  0 ; i  <  blocklen ; ++ i ) {
1084+                             const  int  v0  =  kvalues_iq4nl [b_ptr [l ].qs [k  *  ncols_interleaved  *  blocklen  +  j  *  blocklen  +  i ] &  0x0F ];
1085+                             const  int  v1  =  kvalues_iq4nl [b_ptr [l ].qs [k  *  ncols_interleaved  *  blocklen  +  j  *  blocklen  +  i ] >> 4 ];
1086+                             sumi  +=  ((v0  *  a_ptr [l ].qs [k  *  blocklen  +  i ]) +  (v1  *  a_ptr [l ].qs [k  *  blocklen  +  i  +  qk  / 2 ]));
1087+                         }
1088+                         sumf [j ] +=  sumi  *  GGML_FP16_TO_FP32 (b_ptr [l ].d [j ]) *  GGML_FP16_TO_FP32 (a_ptr [l ].d );
1089+                     }
1090+                 }
1091+             }
1092+             for  (int  j  =  0 ; j  <  ncols_interleaved ; j ++ ) s [x  *  ncols_interleaved  +  j ] =  sumf [j ];
1093+         }
1094+     }
1095+ }
1096+ 
9991097void  ggml_gemm_q4_0_4x4_q8_0 (int  n , float  *  restrict s , size_t  bs , const  void  *  restrict vx , const  void  *  restrict vy , int  nr , int  nc ) {
10001098    const  int  qk  =  QK8_0 ;
10011099    const  int  nb  =  n  / qk ;
@@ -3386,6 +3484,117 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
33863484    }
33873485}
33883486
3487+ void  ggml_gemm_iq4_nl_4x4_q8_0 (int  n , float  *  restrict s , size_t  bs , const  void  *  restrict vx , const  void  *  restrict vy , int  nr , int  nc ) {
3488+     const  int  qk  =  QK8_0 ;
3489+     const  int  nb  =  n  / qk ;
3490+     const  int  ncols_interleaved  =  4 ;
3491+     const  int  blocklen  =  4 ;
3492+ 
3493+     assert  (n  % qk  ==  0 );
3494+     assert  (nr  % 4  ==  0 );
3495+     assert  (nc  % ncols_interleaved  ==  0 );
3496+ 
3497+     UNUSED (s );
3498+     UNUSED (bs );
3499+     UNUSED (vx );
3500+     UNUSED (vy );
3501+     UNUSED (nr );
3502+     UNUSED (nc );
3503+     UNUSED (nb );
3504+     UNUSED (ncols_interleaved );
3505+     UNUSED (blocklen );
3506+ 
3507+ #if  ! ((defined(_MSC_VER )) &&  ! defined(__clang__ )) &&  defined(__aarch64__ ) &&  defined(__ARM_NEON )
3508+     if  (ggml_cpu_has_neon ()) {
3509+         const  int8x16_t  kvalues  =  vld1q_s8 (kvalues_iq4nl );
3510+ 
3511+         for  (int  y  =  0 ; y  <  nr  / 4 ; y ++ ) {
3512+             const  block_q8_0x4  *  a_ptr  =  (const  block_q8_0x4  * ) vy  +  (y  *  nb );
3513+             for  (int  x  =  0 ; x  <  nc  / ncols_interleaved ; x ++ ) {
3514+                 const  block_q4_0x4  *  b_ptr  =  (const  block_q4_0x4  * ) vx  +  (x  *  nb );
3515+ 
3516+                 float32x4_t  sumf [4 ];
3517+                 for  (int  m  =  0 ; m  <  4 ; m ++ ) {
3518+                     sumf [m ] =  vdupq_n_f32 (0 );
3519+                 }
3520+ 
3521+                 for  (int  l  =  0 ; l  <  nb ; l ++ ) {
3522+                     float32x4_t  a_d  =  vcvt_f32_f16 (vld1_f16 ((const  float16_t  * )a_ptr [l ].d ));
3523+                     float32x4_t  b_d  =  vcvt_f32_f16 (vld1_f16 ((const  float16_t  * )b_ptr [l ].d ));
3524+ 
3525+                     int32x4_t  sumi_0  =  vdupq_n_s32 (0 );
3526+                     int32x4_t  sumi_1  =  vdupq_n_s32 (0 );
3527+                     int32x4_t  sumi_2  =  vdupq_n_s32 (0 );
3528+                     int32x4_t  sumi_3  =  vdupq_n_s32 (0 );
3529+ 
3530+                     for  (int  k  =  0 ; k  <  4 ; k ++ ) {
3531+                         int8x16_t  a_0  =  vld1q_s8 (a_ptr [l ].qs  +  16  *  k  +  0 );
3532+                         int8x16_t  a_1  =  vld1q_s8 (a_ptr [l ].qs  +  16  *  k  +  64 );
3533+ 
3534+                         uint8x16_t  b  =  vld1q_u8 (b_ptr [l ].qs  +  16  *  k );
3535+                         int8x16_t  b_hi  =  vqtbl1q_s8 (kvalues , b  >> 4 );
3536+                         int8x16_t  b_lo  =  vqtbl1q_s8 (kvalues , b  &  0xF );
3537+ 
3538+                         sumi_0  =  vdotq_laneq_s32 (sumi_0 , b_lo , a_0 , 0 );
3539+                         sumi_1  =  vdotq_laneq_s32 (sumi_1 , b_lo , a_0 , 1 );
3540+                         sumi_2  =  vdotq_laneq_s32 (sumi_2 , b_lo , a_0 , 2 );
3541+                         sumi_3  =  vdotq_laneq_s32 (sumi_3 , b_lo , a_0 , 3 );
3542+                         sumi_0  =  vdotq_laneq_s32 (sumi_0 , b_hi , a_1 , 0 );
3543+                         sumi_1  =  vdotq_laneq_s32 (sumi_1 , b_hi , a_1 , 1 );
3544+                         sumi_2  =  vdotq_laneq_s32 (sumi_2 , b_hi , a_1 , 2 );
3545+                         sumi_3  =  vdotq_laneq_s32 (sumi_3 , b_hi , a_1 , 3 );
3546+                     }
3547+ 
3548+                     sumf [0 ] =  vmlaq_f32 (sumf [0 ], vmulq_laneq_f32 (b_d , a_d , 0 ), vcvtq_f32_s32 (sumi_0 ));
3549+                     sumf [1 ] =  vmlaq_f32 (sumf [1 ], vmulq_laneq_f32 (b_d , a_d , 1 ), vcvtq_f32_s32 (sumi_1 ));
3550+                     sumf [2 ] =  vmlaq_f32 (sumf [2 ], vmulq_laneq_f32 (b_d , a_d , 2 ), vcvtq_f32_s32 (sumi_2 ));
3551+                     sumf [3 ] =  vmlaq_f32 (sumf [3 ], vmulq_laneq_f32 (b_d , a_d , 3 ), vcvtq_f32_s32 (sumi_3 ));
3552+                 }
3553+ 
3554+                 for  (int  m  =  0 ; m  <  4 ; m ++ ) {
3555+                     vst1q_f32 (s  +  (y  *  4  +  m ) *  bs  +  x  *  4 , sumf [m ]);
3556+                 }
3557+             }
3558+         }
3559+         return ;
3560+     }
3561+ #endif  // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) 
3562+     {
3563+         float  sumf [4 ][4 ];
3564+         int  sumi ;
3565+ 
3566+         for  (int  y  =  0 ; y  <  nr  / 4 ; y ++ ) {
3567+             const  block_q8_0x4  *  a_ptr  =  (const  block_q8_0x4  * ) vy  +  (y  *  nb );
3568+             for  (int  x  =  0 ; x  <  nc  / ncols_interleaved ; x ++ ) {
3569+                 const  block_iq4_nlx4  *  b_ptr  =  (const  block_iq4_nlx4  * ) vx  +  (x  *  nb );
3570+                 for  (int  m  =  0 ; m  <  4 ; m ++ ) {
3571+                     for  (int  j  =  0 ; j  <  ncols_interleaved ; j ++ ) sumf [m ][j ] =  0.0 ;
3572+                 }
3573+                 for  (int  l  =  0 ; l  <  nb ; l ++ ) {
3574+                     for  (int  k  =  0 ; k  <  (qk  / (2  *  blocklen )); k ++ ) {
3575+                         for  (int  m  =  0 ; m  <  4 ; m ++ ) {
3576+                             for  (int  j  =  0 ; j  <  ncols_interleaved ; j ++ ) {
3577+                                 sumi  =  0 ;
3578+                                 for  (int  i  =  0 ; i  <  blocklen ; ++ i ) {
3579+                                     const  int  v0  =  kvalues_iq4nl [b_ptr [l ].qs [k  *  ncols_interleaved  *  blocklen  +  j  *  blocklen  +  i ] &  0x0F ];
3580+                                     const  int  v1  =  kvalues_iq4nl [b_ptr [l ].qs [k  *  ncols_interleaved  *  blocklen  +  j  *  blocklen  +  i ] >> 4 ];
3581+                                     sumi  +=  ((v0  *  a_ptr [l ].qs [k  *  4  *  blocklen  +  m  *  blocklen  +  i ]) + 
3582+                                             (v1  *  a_ptr [l ].qs [k  *  4  *  blocklen  +  m  *  blocklen  +  i  +  qk  / 2  *  4 ]));
3583+                                 }
3584+                                 sumf [m ][j ] +=  sumi  *  GGML_FP16_TO_FP32 (b_ptr [l ].d [j ]) *  GGML_FP16_TO_FP32 (a_ptr [l ].d [m ]);
3585+                             }
3586+                         }
3587+                     }
3588+                 }
3589+                 for  (int  m  =  0 ; m  <  4 ; m ++ ) {
3590+                     for  (int  j  =  0 ; j  <  ncols_interleaved ; j ++ )
3591+                         s [(y  *  4  +  m ) *  bs  +  x  *  ncols_interleaved  +  j ] =  sumf [m ][j ];
3592+                 }
3593+             }
3594+         }
3595+     }
3596+ }
3597+ 
33893598// FIXME: this code is duplicated from ggml-aarch64.c 
33903599static  block_q4_0x4  make_block_q4_0x4 (block_q4_0  *  in , unsigned int   blck_size_interleave ) {
33913600    block_q4_0x4  out ;
@@ -3518,27 +3727,101 @@ static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor *t, int interleave_block,
35183727    GGML_UNUSED (data_size );
35193728}
35203729
3730+ static  block_iq4_nlx4  make_block_iq4_nlx4 (block_iq4_nl  *  in , unsigned int   blck_size_interleave ) {
3731+     block_iq4_nlx4  out ;
3732+ 
3733+     for  (int  i  =  0 ; i  <  4 ; i ++ ) {
3734+         out .d [i ] =  in [i ].d ;
3735+     }
3736+ 
3737+     const  int  end  =  QK4_NL  *  2  / blck_size_interleave ;
3738+ 
3739+     if  (blck_size_interleave  ==  8 ) {
3740+         for  (int  i  =  0 ; i  <  end ; ++ i ) {
3741+             int  src_id  =  i  % 4 ;
3742+             int  src_offset  =  (i  / 4 ) *  blck_size_interleave ;
3743+             int  dst_offset  =  i  *  blck_size_interleave ;
3744+ 
3745+             // Using memcpy to avoid unaligned memory accesses 
3746+             memcpy (& out .qs [dst_offset ], & in [src_id ].qs [src_offset ], sizeof (uint64_t ));
3747+         }
3748+     } else  if  (blck_size_interleave  ==  4 ) {
3749+         for  (int  i  =  0 ; i  <  end ; ++ i ) {
3750+             int  src_id  =  i  % 4 ;
3751+             int  src_offset  =  (i  / 4 ) *  blck_size_interleave ;
3752+             int  dst_offset  =  i  *  blck_size_interleave ;
3753+ 
3754+             memcpy (& out .qs [dst_offset ], & in [src_id ].qs [src_offset ], sizeof (uint32_t ));
3755+         }
3756+     } else  {
3757+         GGML_ASSERT (false);
3758+     }
3759+ 
3760+     return  out ;
3761+ }
3762+ 
3763+ static  int  repack_iq4_nl_to_iq4_nl_4_bl (struct  ggml_tensor  *  t , int  interleave_block , const  void  *  restrict data , size_t  data_size ) {
3764+     GGML_ASSERT (t -> type  ==  GGML_TYPE_IQ4_NL );
3765+     GGML_ASSERT (interleave_block  ==  4  ||  interleave_block  ==  8 );
3766+ 
3767+     block_iq4_nlx4  *  dst  =  (block_iq4_nlx4  * )t -> data ;
3768+     const  block_iq4_nl  *  src  =  (const  block_iq4_nl  * )data ;
3769+     block_iq4_nl  dst_tmp [4 ];
3770+     int  nrow  =  t -> ne [1 ]; // Number of rows 
3771+     int  nrows_interleaved  =  4 ;
3772+     int  nblocks  =  t -> ne [0 ] / QK4_0 ;
3773+ 
3774+     GGML_ASSERT (data_size  ==  nrow  *  nblocks  *  sizeof (block_iq4_nl ));
3775+ 
3776+     if  (nrow  % nrows_interleaved  !=  0  ||  t -> ne [0 ] % 8  !=  0 ) {
3777+         return  -1 ;
3778+     }
3779+ 
3780+     for  (int  b  =  0 ; b  <  nrow ; b  +=  nrows_interleaved ) {
3781+         for  (int64_t  x  =  0 ; x  <  nblocks ; x ++ ) {
3782+             for  (int  i  =  0 ; i  <  nrows_interleaved ; i ++ ) {
3783+                 dst_tmp [i ] =  src [x  +  i  *  nblocks ];
3784+             }
3785+             * dst ++  =  make_block_iq4_nlx4 (dst_tmp , interleave_block );
3786+         }
3787+         src  +=  nrows_interleaved  *  nblocks ;
3788+     }
3789+     return  0 ;
3790+ 
3791+     GGML_UNUSED (data_size );
3792+ }
3793+ 
35213794// Prepare for optimized kernels if applicable 
35223795void  ggml_aarch64_repack_tensor (struct  ggml_tensor  *  cur , enum  ggml_type  repack_type , const  void  *  restrict data , size_t  data_size ) {
35233796    if  (cur -> type  ==  repack_type ) {
35243797        memcpy (cur -> data , data , data_size );
35253798        return ;
35263799    }
35273800
3528-     GGML_ASSERT (cur -> type  ==  GGML_TYPE_Q4_0 );
3529- 
3530-     switch  (repack_type ) {
3531-         case  GGML_TYPE_Q4_0_8_8 :
3532-             repack_q4_0_to_q4_0_8_bl (cur , 8 , data , data_size );
3533-             break ;
3534-         case  GGML_TYPE_Q4_0_4_8 :
3535-             repack_q4_0_to_q4_0_4_bl (cur , 8 , data , data_size );
3536-             break ;
3537-         case  GGML_TYPE_Q4_0_4_4 :
3538-             repack_q4_0_to_q4_0_4_bl (cur , 4 , data , data_size );
3539-             break ;
3540-         default :
3541-             GGML_ABORT ("Unsupported type" );
3801+     if  (cur -> type  ==  GGML_TYPE_Q4_0 ) {
3802+         switch  (repack_type ) {
3803+             case  GGML_TYPE_Q4_0_8_8 :
3804+                 repack_q4_0_to_q4_0_8_bl (cur , 8 , data , data_size );
3805+                 break ;
3806+             case  GGML_TYPE_Q4_0_4_8 :
3807+                 repack_q4_0_to_q4_0_4_bl (cur , 8 , data , data_size );
3808+                 break ;
3809+             case  GGML_TYPE_Q4_0_4_4 :
3810+                 repack_q4_0_to_q4_0_4_bl (cur , 4 , data , data_size );
3811+                 break ;
3812+             default :
3813+                 GGML_ABORT ("Unsupported type" );
3814+         }
3815+     } else  if  (cur -> type  ==  GGML_TYPE_IQ4_NL ) {
3816+         switch  (repack_type ) {
3817+             case  GGML_TYPE_IQ4_NL_4_4 :
3818+                 repack_iq4_nl_to_iq4_nl_4_bl (cur , 4 , data , data_size );
3819+                 break ;
3820+             default :
3821+                 GGML_ABORT ("Unsupported type" );
3822+         }
3823+     } else  {
3824+         GGML_ABORT ("Unsupported type" );
35423825    }
35433826}
35443827
@@ -3554,6 +3837,10 @@ enum ggml_type ggml_aarch64_get_optimal_repack_type(const struct ggml_tensor * c
35543837        if  (ggml_cpu_has_neon ()) {
35553838            return  GGML_TYPE_Q4_0_4_4 ;
35563839        }
3840+     } else  if  (cur -> type  ==  GGML_TYPE_IQ4_NL ) {
3841+         if  (ggml_cpu_has_neon ()) {
3842+             return  GGML_TYPE_IQ4_NL_4_4 ;
3843+         }
35573844    }
35583845
35593846    return  cur -> type ;
0 commit comments