@@ -9104,10 +9104,8 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
91049104
91059105#elif defined __AVX__
91069106
9107- const __m128i m2 = _mm_set1_epi8(2);
91089107 const __m128i m3 = _mm_set1_epi8(3);
91099108 const __m128i m15 = _mm_set1_epi8(15);
9110- const __m128i m32 = _mm_set1_epi8(32);
91119109
91129110 __m256 acc = _mm256_setzero_ps();
91139111
@@ -9119,7 +9117,15 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
91199117 const uint8_t * restrict qh = x[i].qh;
91209118 const int8_t * restrict q8 = y[i].qs;
91219119
9120+ // handle the q6_k -32 offset separately using bsums
9121+ // TODO: tabs, compiler warnings for earlier code
9122+ const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)y[i].bsums);
9123+ const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)y[i].bsums + 1);
91229124 const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
9125+ const __m128i scales_16_0 = _mm_cvtepi8_epi16(scales);
9126+ const __m128i scales_16_1 = _mm_cvtepi8_epi16(_mm_bsrli_si128(scales, 8));
9127+ const __m128i q8scld_0 = _mm_mullo_epi32(_mm_madd_epi16(q8sums_0, scales_16_0), _mm_set1_epi32(32));
9128+ const __m128i q8scld_1 = _mm_mullo_epi32(_mm_madd_epi16(q8sums_1, scales_16_1), _mm_set1_epi32(32));
91239129
91249130 __m128i sumi_0 = _mm_setzero_si128();
91259131 __m128i sumi_1 = _mm_setzero_si128();
@@ -9145,14 +9151,14 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
91459151 const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
91469152 const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
91479153
9148- const __m128i q4_0 = _mm_sub_epi8( _mm_or_si128(_mm_and_si128(q4bits1_0, m15), q4h_0), m32 );
9149- const __m128i q4_1 = _mm_sub_epi8( _mm_or_si128(_mm_and_si128(q4bits1_1, m15), q4h_1), m32 );
9150- const __m128i q4_2 = _mm_sub_epi8( _mm_or_si128(_mm_and_si128(q4bits2_0, m15), q4h_2), m32 );
9151- const __m128i q4_3 = _mm_sub_epi8( _mm_or_si128(_mm_and_si128(q4bits2_1, m15), q4h_3), m32 );
9152- const __m128i q4_4 = _mm_sub_epi8( _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m15), q4h_4), m32 );
9153- const __m128i q4_5 = _mm_sub_epi8( _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m15), q4h_5), m32 );
9154- const __m128i q4_6 = _mm_sub_epi8( _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m15), q4h_6), m32 );
9155- const __m128i q4_7 = _mm_sub_epi8( _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m15), q4h_7), m32 );
9154+ const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m15), q4h_0);
9155+ const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m15), q4h_1);
9156+ const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m15), q4h_2);
9157+ const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m15), q4h_3);
9158+ const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m15), q4h_4);
9159+ const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m15), q4h_5);
9160+ const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m15), q4h_6);
9161+ const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m15), q4h_7);
91569162
91579163 const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
91589164 const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
@@ -9163,14 +9169,14 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
91639169 const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
91649170 const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
91659171
9166- __m128i p16_0 = mul_add_epi8_sse (q4_0, q8_0);
9167- __m128i p16_1 = mul_add_epi8_sse (q4_1, q8_1);
9168- __m128i p16_2 = mul_add_epi8_sse (q4_2, q8_2);
9169- __m128i p16_3 = mul_add_epi8_sse (q4_3, q8_3);
9170- __m128i p16_4 = mul_add_epi8_sse (q4_4, q8_4);
9171- __m128i p16_5 = mul_add_epi8_sse (q4_5, q8_5);
9172- __m128i p16_6 = mul_add_epi8_sse (q4_6, q8_6);
9173- __m128i p16_7 = mul_add_epi8_sse (q4_7, q8_7);
9172+ __m128i p16_0 = _mm_maddubs_epi16 (q4_0, q8_0);
9173+ __m128i p16_1 = _mm_maddubs_epi16 (q4_1, q8_1);
9174+ __m128i p16_2 = _mm_maddubs_epi16 (q4_2, q8_2);
9175+ __m128i p16_3 = _mm_maddubs_epi16 (q4_3, q8_3);
9176+ __m128i p16_4 = _mm_maddubs_epi16 (q4_4, q8_4);
9177+ __m128i p16_5 = _mm_maddubs_epi16 (q4_5, q8_5);
9178+ __m128i p16_6 = _mm_maddubs_epi16 (q4_6, q8_6);
9179+ __m128i p16_7 = _mm_maddubs_epi16 (q4_7, q8_7);
91749180
91759181 const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0));
91769182 const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));
@@ -9191,10 +9197,11 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
91919197 sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
91929198 sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_4, p16_6));
91939199 sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_5, p16_7));
9194-
91959200 }
91969201
9197- __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
9202+ sumi_0 = _mm_sub_epi32(sumi_0, q8scld_0);
9203+ sumi_1 = _mm_sub_epi32(sumi_1, q8scld_1);
9204+ const __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
91989205 acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi)), acc);
91999206 }
92009207
0 commit comments