@@ -346,8 +346,8 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
346346#endif  //  MMQ_SHMEM
347347#endif 
348348
349- #if  defined(DATA_A_Q4_K)
350- //  4-byte loads for Q4_K blocks (144 bytes)
349+ #if  defined(DATA_A_Q4_K)  ||  defined(DATA_A_Q5_K) 
350+ //  4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes) 
351351ACC_TYPE mul_q8_1(const  int32_t q_sum, const  vec2  dma, const  vec2  dsb, const  int32_t sum_divisor) {
352352    return  ACC_TYPE(dsb.x *  dma.x *  float (q_sum) -  dma.y *  dsb.y);
353353}
@@ -361,10 +361,19 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
361361    const  uint  qs_shift =  ((iqs_k %  16 ) /  8 ) *  4 ;
362362
363363    //  Repack 2x4 quants into one int
364+ #if  defined(DATA_A_Q4_K)
364365    const  uint32_t vals0 =  (data_a_packed32[ib_k].qs[qs_idx    ] >>  qs_shift) &  0x0F0F0F0F;
365366    const  uint32_t vals1 =  (data_a_packed32[ib_k].qs[qs_idx +  1 ] >>  qs_shift) &  0x0F0F0F0F;
366367
367368    buf_a[buf_ib].qs[iqs] =  vals0 |  (vals1 <<  4 );
369+ #else  //  defined(DATA_A_Q5_K)
370+     const  uint  qh_idx =  iqs *  QUANT_R_MMQ;
371+     const  uint  qh_shift =  iqs_k /  8 ;
372+ 
373+     buf_a[buf_ib].qs[iqs] =  int32_t(((data_a_packed32[ib_k].qs[qs_idx] >>  qs_shift) &  0x0F0F0F0F) | 
374+                                    (((data_a_packed32[ib_k].qh[qh_idx] >>  qh_shift) &  0x01010101) <<  4 ));
375+ #endif 
376+ 
368377
369378    if  (iqs ==  0 ) {
370379        //  Scale index
@@ -384,7 +393,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
384393void  block_a_to_registers(const  uint  reg_ib, const  uint  buf_ib) {
385394    cache_a[reg_ib].dm =  buf_a[buf_ib].dm;
386395
387-     [[unroll]] for  (uint  iqs =  0 ; iqs <  4 ; iqs++ ) {
396+     [[unroll]] for  (uint  iqs =  0 ; iqs <  8   /  QUANT_R_MMQ ; iqs++ ) {
388397        cache_a[reg_ib].qs[iqs] =  buf_a[buf_ib].qs[iqs];
389398    }
390399}
@@ -393,7 +402,11 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
393402    int32_t q_sum =  0 ;
394403
395404    [[unroll]] for  (uint  iqs =  0 ; iqs <  8 ; iqs++ ) {
405+ #if  defined(DATA_A_Q4_K)
396406        const  int32_t qs_a =  int32_t((cache_a[ib_a].qs[iqs /  2 ] >>  ((iqs %  2 ) *  4 )) &  0x0F0F0F0F);
407+ #else  //  defined(DATA_A_Q5_K)
408+         const  int32_t qs_a =  cache_a[ib_a].qs[iqs];
409+ #endif 
397410
398411        q_sum +=  dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
399412    }
0 commit comments