@@ -532,12 +532,12 @@ void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
532532 int offset = ((k / 2 ) % 2 ) + j * 2 ;
533533 for (int i = 0 ; i < blocklen; ++i) {
534534 const int hbits_index = k * ncols_interleaved * blocklen + j * blocklen + i;
535- const int lbits_index = ( hbits_index / 32 ) * 64 + (hbits_index % 32 ) ;
535+ const int lbits_index = hbits_index + (k/ 4 ) * 256 ;
536536
537- int8_t v0 = (int8_t )((b_ptr[l].qh [hbits_index] & 3 ) << 4 ) | (b_ptr[l].ql [lbits_index] & 0xF ) - 32 ;
538- int8_t v1 = (int8_t )(((b_ptr[l].qh [hbits_index] >> 2 ) & 3 ) << 4 ) | (b_ptr[l].ql [lbits_index + 32 ] & 0xF ) - 32 ;
539- int8_t v2 = (int8_t )(((b_ptr[l].qh [hbits_index] >> 4 ) & 3 ) << 4 ) | ((b_ptr[l].ql [lbits_index] >> 4 ) & 0xF ) - 32 ;
540- int8_t v3 = (int8_t )(((b_ptr[l].qh [hbits_index] >> 6 ) & 3 ) << 4 ) | ((b_ptr[l].ql [lbits_index + 32 ] >> 4 ) & 0xF ) - 32 ;
537+ int8_t v0 = (int8_t )((( b_ptr[l].qh [hbits_index] & 3 ) << 4 ) | (b_ptr[l].ql [lbits_index] & 0xF ) ) - 32 ;
538+ int8_t v1 = (int8_t )(((( b_ptr[l].qh [hbits_index] >> 2 ) & 3 ) << 4 ) | (b_ptr[l].ql [lbits_index + 256 ] & 0xF ) ) - 32 ;
539+ int8_t v2 = (int8_t )(((( b_ptr[l].qh [hbits_index] >> 4 ) & 3 ) << 4 ) | ((b_ptr[l].ql [lbits_index] >> 4 ) & 0xF ) ) - 32 ;
540+ int8_t v3 = (int8_t )(((( b_ptr[l].qh [hbits_index] >> 6 ) & 3 ) << 4 ) | ((b_ptr[l].ql [lbits_index + 256 ] >> 4 ) & 0xF ) ) - 32 ;
541541
542542 sumi1 = (v0 * a_ptr[l].qs [(k >> 2 ) * 128 + (k % 4 ) * blocklen + i]);
543543 sumi2 = (v1 * a_ptr[l].qs [(k >> 2 ) * 128 + (k % 4 ) * blocklen + i + 32 ]);
@@ -556,6 +556,7 @@ void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
556556 }
557557 for (int j = 0 ; j < ncols_interleaved; j++) {
558558 s[x * ncols_interleaved + j] = sumf[j];
559+
559560 }
560561 }
561562}
@@ -999,21 +1000,21 @@ void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
9991000 int offset = ((k / 2 ) % 2 ) + j * 2 ;
10001001 for (int i = 0 ; i < blocklen; ++i){
10011002 const int hbits_index = k * ncols_interleaved * blocklen + j * blocklen + i;
1002- const int lbits_index = ( hbits_index / 32 ) * 64 + (hbits_index % 32 ) ;
1003+ const int lbits_index = hbits_index + (k/ 4 ) * 256 ;
10031004
1004- int8_t v0 = (int8_t )((b_ptr[l].qh [hbits_index] & 3 ) << 4 ) | (b_ptr[l].ql [lbits_index] & 0xF ) - 32 ;
1005- int8_t v1 = (int8_t )(((b_ptr[l].qh [hbits_index] >> 2 ) & 3 ) << 4 ) | (b_ptr[l].ql [lbits_index + 32 ] & 0xF ) - 32 ;
1006- int8_t v2 = (int8_t )(((b_ptr[l].qh [hbits_index] >> 4 ) & 3 ) << 4 ) | ((b_ptr[l].ql [lbits_index] >> 4 ) & 0xF ) - 32 ;
1007- int8_t v3 = (int8_t )(((b_ptr[l].qh [hbits_index] >> 6 ) & 3 ) << 4 ) | ((b_ptr[l].ql [lbits_index + 32 ] >> 4 ) & 0xF ) - 32 ;
1005+ int8_t v0 = (int8_t )((( b_ptr[l].qh [hbits_index] & 3 ) << 4 ) | (b_ptr[l].ql [lbits_index] & 0xF ) ) - 32 ;
1006+ int8_t v1 = (int8_t )(((( b_ptr[l].qh [hbits_index] >> 2 ) & 3 ) << 4 ) | (b_ptr[l].ql [lbits_index + 256 ] & 0xF ) ) - 32 ;
1007+ int8_t v2 = (int8_t )(((( b_ptr[l].qh [hbits_index] >> 4 ) & 3 ) << 4 ) | ((b_ptr[l].ql [lbits_index] >> 4 ) & 0xF ) ) - 32 ;
1008+ int8_t v3 = (int8_t )(((( b_ptr[l].qh [hbits_index] >> 6 ) & 3 ) << 4 ) | ((b_ptr[l].ql [lbits_index + 256 ] >> 4 ) & 0xF ) ) - 32 ;
10081009
10091010 sumi1 = (v0 * a_ptr[l].qs [(k >> 2 ) * 512 + (k % 4 ) * 4 * blocklen + m * blocklen + i]);
10101011 sumi2 = (v1 * a_ptr[l].qs [(k >> 2 ) * 512 + (k % 4 ) * 4 * blocklen + m * blocklen + i + 128 ]);
10111012 sumi3 = (v2 * a_ptr[l].qs [(k >> 2 ) * 512 + (k % 4 ) * 4 * blocklen + m * blocklen + i + 256 ]);
10121013 sumi4 = (v3 * a_ptr[l].qs [(k >> 2 ) * 512 + (k % 4 ) * 4 * blocklen + m * blocklen + i + 384 ]);
1013- sumi1 = sumi1 * (scales_0[offset] & 0xF );
1014- sumi2 = sumi2 * (scales_1[offset] & 0xF );
1015- sumi3 = sumi3 * (scales_2[offset] & 0xF );
1016- sumi4 = sumi4 * (scales_3[offset] & 0xF );
1014+ sumi1 = sumi1 * (scales_0[offset]);
1015+ sumi2 = sumi2 * (scales_1[offset]);
1016+ sumi3 = sumi3 * (scales_2[offset]);
1017+ sumi4 = sumi4 * (scales_3[offset]);
10171018 sumi += sumi1 + sumi2 + sumi3 + sumi4;
10181019 }
10191020 sumf[m][j] += sumi * GGML_FP16_TO_FP32 (b_ptr[l].d [j]) * a_ptr[l].d [m];
0 commit comments