@@ -9118,14 +9118,13 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
91189118 const int8_t * restrict q8 = y[i].qs;
91199119
91209120 // handle the q6_k -32 offset separately using bsums
9121- // TODO: tabs, compiler warnings for earlier code
91229121 const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)y[i].bsums);
91239122 const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)y[i].bsums + 1);
91249123 const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
91259124 const __m128i scales_16_0 = _mm_cvtepi8_epi16(scales);
91269125 const __m128i scales_16_1 = _mm_cvtepi8_epi16(_mm_bsrli_si128(scales, 8));
9127- const __m128i q8scld_0 = _mm_mullo_epi32 (_mm_madd_epi16(q8sums_0, scales_16_0), _mm_set1_epi32(32) );
9128- const __m128i q8scld_1 = _mm_mullo_epi32 (_mm_madd_epi16(q8sums_1, scales_16_1), _mm_set1_epi32(32) );
9126+ const __m128i q8sclsub_0 = _mm_slli_epi32 (_mm_madd_epi16(q8sums_0, scales_16_0), 5 );
9127+ const __m128i q8sclsub_1 = _mm_slli_epi32 (_mm_madd_epi16(q8sums_1, scales_16_1), 5 );
91299128
91309129 __m128i sumi_0 = _mm_setzero_si128();
91319130 __m128i sumi_1 = _mm_setzero_si128();
@@ -9139,12 +9138,12 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
91399138
91409139 const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4);
91419140 const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4);
9142- const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(0x0C )), 2);
9143- const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(0x0C )), 2);
9144- const __m128i q4h_4 = _mm_and_si128(q4bitsH_0, _mm_set1_epi8(0x30 ));
9145- const __m128i q4h_5 = _mm_and_si128(q4bitsH_1, _mm_set1_epi8(0x30 ));
9146- const __m128i q4h_6 = _mm_srli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(0xC0 )), 2);
9147- const __m128i q4h_7 = _mm_srli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(0xC0 )), 2);
9141+ const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(12 )), 2);
9142+ const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(12 )), 2);
9143+ const __m128i q4h_4 = _mm_and_si128(q4bitsH_0, _mm_set1_epi8(48 ));
9144+ const __m128i q4h_5 = _mm_and_si128(q4bitsH_1, _mm_set1_epi8(48 ));
9145+ const __m128i q4h_6 = _mm_srli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(-64 )), 2);
9146+ const __m128i q4h_7 = _mm_srli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(-64 )), 2);
91489147
91499148 const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
91509149 const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
@@ -9185,22 +9184,22 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
91859184 is += 4;
91869185
91879186 p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
9188- p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64 (scale_0, scale_0 )), p16_1);
9187+ p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128 (scale_0, 8 )), p16_1);
91899188 p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
9190- p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64 (scale_1, scale_1 )), p16_3);
9189+ p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128 (scale_1, 8 )), p16_3);
91919190 p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4);
9192- p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64 (scale_2, scale_2 )), p16_5);
9191+ p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128 (scale_2, 8 )), p16_5);
91939192 p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6);
9194- p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64 (scale_3, scale_3 )), p16_7);
9193+ p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128 (scale_3, 8 )), p16_7);
91959194
91969195 sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
91979196 sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
91989197 sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_4, p16_6));
91999198 sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_5, p16_7));
92009199 }
92019200
9202- sumi_0 = _mm_sub_epi32(sumi_0, q8scld_0 );
9203- sumi_1 = _mm_sub_epi32(sumi_1, q8scld_1 );
9201+ sumi_0 = _mm_sub_epi32(sumi_0, q8sclsub_0 );
9202+ sumi_1 = _mm_sub_epi32(sumi_1, q8sclsub_1 );
92049203 const __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
92059204 acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi)), acc);
92069205 }
0 commit comments