@@ -9107,9 +9107,9 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
     const __m128i m4 = _mm_set1_epi8(0xF);
     const __m128i m3 = _mm_set1_epi8(3);
     const __m128i m32s = _mm_set1_epi8(32);
+    const __m128i m2 = _mm_set1_epi8(2);

-    __m256 acc1 = _mm256_setzero_ps();
-    __m256 acc2 = _mm256_setzero_ps();
+    __m256 acc = _mm256_setzero_ps();

     for (int i = 0; i < nb; ++i) {

@@ -9123,8 +9123,6 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r

         __m128i sumi_0 = _mm_setzero_si128();
         __m128i sumi_1 = _mm_setzero_si128();
-        __m128i sumi_2 = _mm_setzero_si128();
-        __m128i sumi_3 = _mm_setzero_si128();

         __m128i shuffle = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
         for (int j = 0; j < QK_K/128; ++j) {
@@ -9134,90 +9132,75 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r

             const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4);
             const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4);
+            const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 2), m3), 4);
+            const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 2), m3), 4);
             const __m128i q4h_4 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 4), m3), 4);
             const __m128i q4h_5 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 4), m3), 4);
+            const __m128i q4h_6 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 6), m3), 4);
+            const __m128i q4h_7 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 6), m3), 4);

             const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
             const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+            const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+            const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+
+            const __m128i q4_0 = _mm_sub_epi8(_mm_or_si128(_mm_and_si128(q4bits1_0, m4), q4h_0), m32s);
+            const __m128i q4_1 = _mm_sub_epi8(_mm_or_si128(_mm_and_si128(q4bits1_1, m4), q4h_1), m32s);
+            const __m128i q4_2 = _mm_sub_epi8(_mm_or_si128(_mm_and_si128(q4bits2_0, m4), q4h_2), m32s);
+            const __m128i q4_3 = _mm_sub_epi8(_mm_or_si128(_mm_and_si128(q4bits2_1, m4), q4h_3), m32s);
+            const __m128i q4_4 = _mm_sub_epi8(_mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m4), q4h_4), m32s);
+            const __m128i q4_5 = _mm_sub_epi8(_mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m4), q4h_5), m32s);
+            const __m128i q4_6 = _mm_sub_epi8(_mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m4), q4h_6), m32s);
+            const __m128i q4_7 = _mm_sub_epi8(_mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m4), q4h_7), m32s);

-            const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m4), q4h_0);
-            const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m4), q4h_1);
-            const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m4), q4h_4);
-            const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m4), q4h_5);
             const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
-            const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 48;
+            const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
             const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
-            const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 -= 48;
-            __m128i q8s_0 = _mm_maddubs_epi16(m32s, q8_0);
-            __m128i q8s_1 = _mm_maddubs_epi16(m32s, q8_1);
-            __m128i q8s_4 = _mm_maddubs_epi16(m32s, q8_4);
-            __m128i q8s_5 = _mm_maddubs_epi16(m32s, q8_5);
-            __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0);
-            __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1);
-            __m128i p16_4 = _mm_maddubs_epi16(q4_4, q8_4);
-            __m128i p16_5 = _mm_maddubs_epi16(q4_5, q8_5);
-            p16_0 = _mm_sub_epi16(p16_0, q8s_0);
-            p16_1 = _mm_sub_epi16(p16_1, q8s_1);
-            p16_4 = _mm_sub_epi16(p16_4, q8s_4);
-            p16_5 = _mm_sub_epi16(p16_5, q8s_5);
+            const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+
+            __m128i p16_0 = mul_add_epi8_sse(q4_0, q8_0);
+            __m128i p16_1 = mul_add_epi8_sse(q4_1, q8_1);
+            __m128i p16_2 = mul_add_epi8_sse(q4_2, q8_2);
+            __m128i p16_3 = mul_add_epi8_sse(q4_3, q8_3);
+            __m128i p16_4 = mul_add_epi8_sse(q4_4, q8_4);
+            __m128i p16_5 = mul_add_epi8_sse(q4_5, q8_5);
+            __m128i p16_6 = mul_add_epi8_sse(q4_6, q8_6);
+            __m128i p16_7 = mul_add_epi8_sse(q4_7, q8_7);

             const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle);
-            const __m128i scale_2 = _mm_shuffle_epi8(scales, _mm_add_epi8(shuffle, _mm_set1_epi8(4)));
+            shuffle = _mm_add_epi8(shuffle, m2);
+            const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle);
+            shuffle = _mm_add_epi8(shuffle, m2);
+            const __m128i scale_2 = _mm_shuffle_epi8(scales, shuffle);
+            shuffle = _mm_add_epi8(shuffle, m2);
+            const __m128i scale_3 = _mm_shuffle_epi8(scales, shuffle);
+            shuffle = _mm_add_epi8(shuffle, m2);
+
             p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
             p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1);
-            p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4);
-            p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_2, scale_2)), p16_5);
-
-            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_1));
-            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_4, p16_5));
-
-            const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
-            const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
-            const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 2), m3), 4);
-            const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 2), m3), 4);
-            const __m128i q4h_6 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 6), m3), 4);
-            const __m128i q4h_7 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 6), m3), 4);
-            const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m4), q4h_2);
-            const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m4), q4h_3);
-            const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m4), q4h_6);
-            const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m4), q4h_7);
-
-            const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
-            const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 48;
-            const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
-            const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
-            __m128i q8s_2 = _mm_maddubs_epi16(m32s, q8_2);
-            __m128i q8s_3 = _mm_maddubs_epi16(m32s, q8_3);
-            __m128i q8s_6 = _mm_maddubs_epi16(m32s, q8_6);
-            __m128i q8s_7 = _mm_maddubs_epi16(m32s, q8_7);
-            __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2);
-            __m128i p16_3 = _mm_maddubs_epi16(q4_3, q8_3);
-            __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6);
-            __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7);
-            p16_2 = _mm_sub_epi16(p16_2, q8s_2);
-            p16_3 = _mm_sub_epi16(p16_3, q8s_3);
-            p16_6 = _mm_sub_epi16(p16_6, q8s_6);
-            p16_7 = _mm_sub_epi16(p16_7, q8s_7);
-
-            const __m128i scale_1 = _mm_shuffle_epi8(scales, _mm_add_epi8(shuffle, _mm_set1_epi8(2)));
-            const __m128i scale_3 = _mm_shuffle_epi8(scales, _mm_add_epi8(shuffle, _mm_set1_epi8(6)));
             p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
             p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3);
+            p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4);
+            p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_2, scale_2)), p16_5);
             p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6);
             p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_3, scale_3)), p16_7);

-            sumi_2 = _mm_add_epi32(sumi_2, _mm_add_epi32(p16_2, p16_3));
-            sumi_3 = _mm_add_epi32(sumi_3, _mm_add_epi32(p16_6, p16_7));
+            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
+            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
+            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_4, p16_6));
+            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_5, p16_7));

         }

-        __m256i sumi1 = MM256_SET_M128I(sumi_0, sumi_1);
-        __m256i sumi2 = MM256_SET_M128I(sumi_2, sumi_3);
-        acc1 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi1)), acc1);
-        acc2 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi2)), acc2);
+        __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
     }

-    *s = hsum_float_8(_mm256_add_ps(acc1, acc2));
+    *s = hsum_float_8(acc);

 #elif defined __riscv_v_intrinsic

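Note on the new `mul_add_epi8_sse` calls: the helper is defined earlier in ggml-quants.c and is not part of this hunk. Because the rewritten code folds the `-32` offset into the `q4_*` values with `_mm_sub_epi8`, the quants are genuinely signed bytes, so the old `q8s_*` correction terms (`32 * sum(q8)`) disappear and the multiply has to handle signed operands. A minimal sketch of what the helper is expected to do, assuming the usual SSSE3 sign-trick formulation used elsewhere in ggml:

```c
// Multiply signed 8-bit lanes of x and y, horizontally adding adjacent
// products into 16-bit lanes. _mm_maddubs_epi16 treats its first operand
// as unsigned, so move the sign of x onto y first.
static inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) {
    const __m128i ax = _mm_sign_epi8(x, x); // ax = |x| per byte (usable as unsigned)
    const __m128i sy = _mm_sign_epi8(y, x); // sy = y with the sign of x applied
    return _mm_maddubs_epi16(ax, sy);       // adjacent ax*sy pairs summed to 16 bits
}
```

This replaces the old two-step pattern `_mm_maddubs_epi16(q4, q8)` followed by subtracting `_mm_maddubs_epi16(m32s, q8)`, trading two multiplies and a subtract per vector for two byte-sign operations and one multiply.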