@@ -9133,10 +9133,8 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r

 #elif defined __AVX__

-    const __m128i m4 = _mm_set1_epi8(0xF);
     const __m128i m3 = _mm_set1_epi8(3);
-    const __m128i m32s = _mm_set1_epi8(32);
-    const __m128i m2 = _mm_set1_epi8(2);
+    const __m128i m15 = _mm_set1_epi8(15);

     __m256 acc = _mm256_setzero_ps();

@@ -9148,39 +9146,47 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         const uint8_t * restrict qh = x[i].qh;
         const int8_t  * restrict q8 = y[i].qs;

+        // handle the q6_k -32 offset separately using bsums
+        const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)y[i].bsums);
+        const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)y[i].bsums + 1);
         const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
+        const __m128i scales_16_0 = _mm_cvtepi8_epi16(scales);
+        const __m128i scales_16_1 = _mm_cvtepi8_epi16(_mm_bsrli_si128(scales, 8));
+        const __m128i q8sclsub_0 = _mm_slli_epi32(_mm_madd_epi16(q8sums_0, scales_16_0), 5);
+        const __m128i q8sclsub_1 = _mm_slli_epi32(_mm_madd_epi16(q8sums_1, scales_16_1), 5);

         __m128i sumi_0 = _mm_setzero_si128();
         __m128i sumi_1 = _mm_setzero_si128();

-        __m128i shuffle = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
+        int is = 0;
+
         for (int j = 0; j < QK_K/128; ++j) {

             const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16;
             const __m128i q4bitsH_1 = _mm_loadu_si128((const __m128i*)qh); qh += 16;

             const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4);
             const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4);
-            const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 2), m3), 4);
-            const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 2), m3), 4);
-            const __m128i q4h_4 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 4), m3), 4);
-            const __m128i q4h_5 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 4), m3), 4);
-            const __m128i q4h_6 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 6), m3), 4);
-            const __m128i q4h_7 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 6), m3), 4);
+            const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(12)), 2);
+            const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(12)), 2);
+            const __m128i q4h_4 = _mm_and_si128(q4bitsH_0, _mm_set1_epi8(48));
+            const __m128i q4h_5 = _mm_and_si128(q4bitsH_1, _mm_set1_epi8(48));
+            const __m128i q4h_6 = _mm_srli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(-64)), 2);
+            const __m128i q4h_7 = _mm_srli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(-64)), 2);

             const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
             const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
             const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
             const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;

-            const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m4), q4h_0);
-            const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m4), q4h_1);
-            const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m4), q4h_2);
-            const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m4), q4h_3);
-            const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m4), q4h_4);
-            const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m4), q4h_5);
-            const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m4), q4h_6);
-            const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m4), q4h_7);
+            const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m15), q4h_0);
+            const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m15), q4h_1);
+            const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m15), q4h_2);
+            const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m15), q4h_3);
+            const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m15), q4h_4);
+            const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m15), q4h_5);
+            const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m15), q4h_6);
+            const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m15), q4h_7);

             const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
             const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
@@ -9191,15 +9197,6 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
             const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
             const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;

-            __m128i q8s_0 = _mm_maddubs_epi16(m32s, q8_0);
-            __m128i q8s_1 = _mm_maddubs_epi16(m32s, q8_1);
-            __m128i q8s_2 = _mm_maddubs_epi16(m32s, q8_2);
-            __m128i q8s_3 = _mm_maddubs_epi16(m32s, q8_3);
-            __m128i q8s_4 = _mm_maddubs_epi16(m32s, q8_4);
-            __m128i q8s_5 = _mm_maddubs_epi16(m32s, q8_5);
-            __m128i q8s_6 = _mm_maddubs_epi16(m32s, q8_6);
-            __m128i q8s_7 = _mm_maddubs_epi16(m32s, q8_7);
-
             __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0);
             __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1);
             __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2);
@@ -9209,32 +9206,20 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
             __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6);
             __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7);

-            p16_0 = _mm_sub_epi16(p16_0, q8s_0);
-            p16_1 = _mm_sub_epi16(p16_1, q8s_1);
-            p16_2 = _mm_sub_epi16(p16_2, q8s_2);
-            p16_3 = _mm_sub_epi16(p16_3, q8s_3);
-            p16_4 = _mm_sub_epi16(p16_4, q8s_4);
-            p16_5 = _mm_sub_epi16(p16_5, q8s_5);
-            p16_6 = _mm_sub_epi16(p16_6, q8s_6);
-            p16_7 = _mm_sub_epi16(p16_7, q8s_7);
-
-            const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle);
-            shuffle = _mm_add_epi8(shuffle, m2);
-            const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle);
-            shuffle = _mm_add_epi8(shuffle, m2);
-            const __m128i scale_2 = _mm_shuffle_epi8(scales, shuffle);
-            shuffle = _mm_add_epi8(shuffle, m2);
-            const __m128i scale_3 = _mm_shuffle_epi8(scales, shuffle);
-            shuffle = _mm_add_epi8(shuffle, m2);
+            const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0));
+            const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));
+            const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2));
+            const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3));
+            is += 4;

             p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
-            p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1);
+            p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_0, 8)), p16_1);
             p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
-            p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3);
+            p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_1, 8)), p16_3);
             p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4);
-            p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_2, scale_2)), p16_5);
+            p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_2, 8)), p16_5);
             p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6);
-            p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_3, scale_3)), p16_7);
+            p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_3, 8)), p16_7);

             sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
             sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
@@ -9243,8 +9228,10 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r

         }

-        __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
-        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
+        sumi_0 = _mm_sub_epi32(sumi_0, q8sclsub_0);
+        sumi_1 = _mm_sub_epi32(sumi_1, q8sclsub_1);
+        const __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi)), acc);
     }

     *s = hsum_float_8(acc);
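
The heart of this diff is an algebraic split of the q6_K offset. Each q6_K weight is stored as an unsigned 6-bit value with an implicit -32 bias, so for a 16-element sub-block:

scale * sum((q6 - 32) * q8) = scale * sum(q6 * q8) - 32 * scale * sum(q8)

The second term depends only on the per-group q8 sums that q8_K already precomputes in `bsums`, which is why the old per-iteration `m32s` correction can be hoisted out of the inner loop as `q8sclsub_{0,1} = _mm_slli_epi32(_mm_madd_epi16(q8sums, scales_16), 5)` and subtracted once per block. A minimal scalar sketch of the same factoring (`dot_q6_group` and its arguments are a hypothetical illustration, not ggml API):

```c
#include <stdint.h>

// Hypothetical scalar model of one 16-element q6_K sub-block dot product.
// q6[k] holds the reassembled 6-bit quant in [0, 63]; the actual weight is
// q6[k] - 32. bsum is the precomputed sum of the 16 q8 values, which q8_K
// stores in y[i].bsums.
static int32_t dot_q6_group(const uint8_t q6[16], const int8_t q8[16],
                            int8_t scale, int16_t bsum) {
    int32_t sumi = 0;
    for (int k = 0; k < 16; ++k) {
        sumi += (int32_t) q6[k] * q8[k];  // plays the role of _mm_maddubs_epi16
    }
    // scale * sum((q6 - 32) * q8) == scale * sumi - 32 * scale * bsum;
    // the SIMD path computes the subtrahend once per block, shifting left
    // by 5 instead of multiplying by 32.
    return (int32_t) scale * sumi - (((int32_t) scale * bsum) << 5);
}
```

The remaining changes are strength reductions with identical results: masking the high bits in place with `12`, `48`, and `-64` (0x0C, 0x30, 0xC0) saves one shift per extraction versus shift-then-mask-with-`m3` (for example, `(qh & 0x0C) << 2` equals `((qh >> 2) & 3) << 4`), and the incrementally updated `shuffle` register is replaced by the precomputed `get_scale_shuffle(is)` table lookup.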