@@ -241,13 +241,19 @@ static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
241241// larger version of mul_sum_i8_pairs_float where x and y are each represented by four 128-bit vectors
242242static inline __m256 mul_sum_i8_quad_float(const __m128i x_1_0, const __m128i x_1_1, const __m128i x_2_0, const __m128i x_2_1,
243243 const __m128i y_1_0, const __m128i y_1_1, const __m128i y_2_0, const __m128i y_2_1) {
244+ const __m128i mone = _mm_set1_epi16(1);
245+
244246 const __m128i p16_1_0 = mul_add_epi8_sse(x_1_0, y_1_0);
245247 const __m128i p16_1_1 = mul_add_epi8_sse(x_1_1, y_1_1);
246248 const __m128i p16_2_0 = mul_add_epi8_sse(x_2_0, y_2_0);
247249 const __m128i p16_2_1 = mul_add_epi8_sse(x_2_1, y_2_1);
248- __m128i p_1 = _mm_add_epi16(p16_1_0, p16_1_1);
249- __m128i p_2 = _mm_add_epi16(p16_2_0, p16_2_1);
250- return sum_i16_pairs_float(p_2, p_1);
250+ const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, mone);
251+ const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, mone);
252+ const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, mone);
253+ const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, mone);
254+ const __m128i p_1 = _mm_add_epi32(p_1_0, p_1_1);
255+ const __m128i p_2 = _mm_add_epi32(p_2_0, p_2_1);
256+ return _mm256_cvtepi32_ps(MM256_SET_M128I(p_2, p_1));
251257}
252258
253259// fp16 delta calculation intended for mul_sum_i8_quad_float
0 commit comments