@@ -9107,9 +9107,9 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
     const __m128i m4 = _mm_set1_epi8(0xF);
     const __m128i m3 = _mm_set1_epi8(3);
     const __m128i m32s = _mm_set1_epi8(32);
-    const __m128i m2 = _mm_set1_epi8(2);
 
-    __m256 acc = _mm256_setzero_ps();
+    __m256 acc1 = _mm256_setzero_ps();
+    __m256 acc2 = _mm256_setzero_ps();
 
     for (int i = 0; i < nb; ++i) {
 
@@ -9123,6 +9123,8 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
         __m128i sumi_0 = _mm_setzero_si128();
         __m128i sumi_1 = _mm_setzero_si128();
+        __m128i sumi_2 = _mm_setzero_si128();
+        __m128i sumi_3 = _mm_setzero_si128();
 
         __m128i shuffle = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
         for (int j = 0; j < QK_K/128; ++j) {
@@ -9132,93 +9134,94 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
             const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4);
             const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4);
-            const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 2), m3), 4);
-            const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 2), m3), 4);
             const __m128i q4h_4 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 4), m3), 4);
             const __m128i q4h_5 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 4), m3), 4);
-            const __m128i q4h_6 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 6), m3), 4);
-            const __m128i q4h_7 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 6), m3), 4);
 
             const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
             const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
-            const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
-            const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
 
             const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m4), q4h_0);
             const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m4), q4h_1);
-            const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m4), q4h_2);
-            const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m4), q4h_3);
             const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m4), q4h_4);
             const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m4), q4h_5);
-            const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m4), q4h_6);
-            const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m4), q4h_7);
-
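+            // first pass loads q8 rows 0,1,4,5: after row 1 the pointer jumps 48 bytes
+            // ahead to row 4, and after row 5 it rewinds 48 bytes back to row 2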
             const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
-            const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
-            const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
-            const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 48;
             const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
-            const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
-            const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
-            const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
-
+            const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 -= 48;
             __m128i q8s_0 = _mm_maddubs_epi16(m32s, q8_0);
             __m128i q8s_1 = _mm_maddubs_epi16(m32s, q8_1);
-            __m128i q8s_2 = _mm_maddubs_epi16(m32s, q8_2);
-            __m128i q8s_3 = _mm_maddubs_epi16(m32s, q8_3);
             __m128i q8s_4 = _mm_maddubs_epi16(m32s, q8_4);
             __m128i q8s_5 = _mm_maddubs_epi16(m32s, q8_5);
-            __m128i q8s_6 = _mm_maddubs_epi16(m32s, q8_6);
-            __m128i q8s_7 = _mm_maddubs_epi16(m32s, q8_7);
-
             __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0);
             __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1);
-            __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2);
-            __m128i p16_3 = _mm_maddubs_epi16(q4_3, q8_3);
             __m128i p16_4 = _mm_maddubs_epi16(q4_4, q8_4);
             __m128i p16_5 = _mm_maddubs_epi16(q4_5, q8_5);
-            __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6);
-            __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7);
-
             p16_0 = _mm_sub_epi16(p16_0, q8s_0);
             p16_1 = _mm_sub_epi16(p16_1, q8s_1);
-            p16_2 = _mm_sub_epi16(p16_2, q8s_2);
-            p16_3 = _mm_sub_epi16(p16_3, q8s_3);
             p16_4 = _mm_sub_epi16(p16_4, q8s_4);
             p16_5 = _mm_sub_epi16(p16_5, q8s_5);
-            p16_6 = _mm_sub_epi16(p16_6, q8s_6);
-            p16_7 = _mm_sub_epi16(p16_7, q8s_7);
 
             const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle);
-            shuffle = _mm_add_epi8(shuffle, m2);
-            const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle);
-            shuffle = _mm_add_epi8(shuffle, m2);
-            const __m128i scale_2 = _mm_shuffle_epi8(scales, shuffle);
-            shuffle = _mm_add_epi8(shuffle, m2);
-            const __m128i scale_3 = _mm_shuffle_epi8(scales, shuffle);
-            shuffle = _mm_add_epi8(shuffle, m2);
-
+            const __m128i scale_2 = _mm_shuffle_epi8(scales, _mm_add_epi8(shuffle, _mm_set1_epi8(4)));
             p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
             p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1);
-            p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
-            p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3);
             p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4);
             p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_2, scale_2)), p16_5);
+
+            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_1));
+            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_4, p16_5));
+
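+            // second pass: rows 2,3 and 6,7, built from the next 32 bytes of q4
+            // and the high bits of q4bitsH shifted right by 2 and 6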
+            const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+            const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+            const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 2), m3), 4);
+            const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 2), m3), 4);
+            const __m128i q4h_6 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 6), m3), 4);
+            const __m128i q4h_7 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 6), m3), 4);
+            const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m4), q4h_2);
+            const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m4), q4h_3);
+            const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m4), q4h_6);
+            const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m4), q4h_7);
+
+            const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 48;
+            const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            __m128i q8s_2 = _mm_maddubs_epi16(m32s, q8_2);
+            __m128i q8s_3 = _mm_maddubs_epi16(m32s, q8_3);
+            __m128i q8s_6 = _mm_maddubs_epi16(m32s, q8_6);
+            __m128i q8s_7 = _mm_maddubs_epi16(m32s, q8_7);
+            __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2);
+            __m128i p16_3 = _mm_maddubs_epi16(q4_3, q8_3);
+            __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6);
+            __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7);
+            p16_2 = _mm_sub_epi16(p16_2, q8s_2);
+            p16_3 = _mm_sub_epi16(p16_3, q8s_3);
+            p16_6 = _mm_sub_epi16(p16_6, q8s_6);
+            p16_7 = _mm_sub_epi16(p16_7, q8s_7);
+
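+            // scale bytes 2,3 and 6,7 are picked at offsets 2 and 6 from the unmodified shuffle pattern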
+            const __m128i scale_1 = _mm_shuffle_epi8(scales, _mm_add_epi8(shuffle, _mm_set1_epi8(2)));
+            const __m128i scale_3 = _mm_shuffle_epi8(scales, _mm_add_epi8(shuffle, _mm_set1_epi8(6)));
+            p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
+            p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3);
             p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6);
             p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_3, scale_3)), p16_7);
 
-            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
-            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
-            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_4, p16_6));
-            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_5, p16_7));
+            sumi_2 = _mm_add_epi32(sumi_2, _mm_add_epi32(p16_2, p16_3));
+            sumi_3 = _mm_add_epi32(sumi_3, _mm_add_epi32(p16_6, p16_7));
+            shuffle = _mm_add_epi8(shuffle, _mm_set1_epi8(8)); // advance to scales 8..15 for the next j iteration (the four m2 increments used to do this)
 
         }
 
-        __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
-        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
+        __m256i sumi1 = MM256_SET_M128I(sumi_0, sumi_1);
+        __m256i sumi2 = MM256_SET_M128I(sumi_2, sumi_3);
+        acc1 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi1)), acc1);
+        acc2 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi2)), acc2);
     }
 
-    *s = hsum_float_8(acc);
+    *s = hsum_float_8(_mm256_add_ps(acc1, acc2));
 
 #elif defined __riscv_v_intrinsic
 