Skip to content

Commit 35255d6

Browse files
committed
Handle the q6_K -32 offset separately in the AVX path using bsums. bsums exists for a reason!
1 parent a420e4c commit 35255d6

File tree

1 file changed

+27
-20
lines changed

1 file changed

+27
-20
lines changed

ggml/src/ggml-quants.c

Lines changed: 27 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -9104,10 +9104,8 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
91049104

91059105
#elif defined __AVX__
91069106

9107-
const __m128i m2 = _mm_set1_epi8(2);
91089107
const __m128i m3 = _mm_set1_epi8(3);
91099108
const __m128i m15 = _mm_set1_epi8(15);
9110-
const __m128i m32 = _mm_set1_epi8(32);
91119109

91129110
__m256 acc = _mm256_setzero_ps();
91139111

@@ -9119,7 +9117,15 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
91199117
const uint8_t * restrict qh = x[i].qh;
91209118
const int8_t * restrict q8 = y[i].qs;
91219119

9120+
// handle the q6_K -32 offset separately: multiply the unsigned 6-bit values by q8, then subtract 32 * madd(bsums, scales) once after the inner loop
9121+
// TODO: clean up tabs and fix compiler warnings in the earlier code
9122+
const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)y[i].bsums);
9123+
const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)y[i].bsums + 1);
91229124
const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
9125+
const __m128i scales_16_0 = _mm_cvtepi8_epi16(scales);
9126+
const __m128i scales_16_1 = _mm_cvtepi8_epi16(_mm_bsrli_si128(scales, 8));
9127+
const __m128i q8scld_0 = _mm_mullo_epi32(_mm_madd_epi16(q8sums_0, scales_16_0), _mm_set1_epi32(32));
9128+
const __m128i q8scld_1 = _mm_mullo_epi32(_mm_madd_epi16(q8sums_1, scales_16_1), _mm_set1_epi32(32));
91239129

91249130
__m128i sumi_0 = _mm_setzero_si128();
91259131
__m128i sumi_1 = _mm_setzero_si128();
@@ -9145,14 +9151,14 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
91459151
const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
91469152
const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
91479153

9148-
const __m128i q4_0 = _mm_sub_epi8(_mm_or_si128(_mm_and_si128(q4bits1_0, m15), q4h_0), m32);
9149-
const __m128i q4_1 = _mm_sub_epi8(_mm_or_si128(_mm_and_si128(q4bits1_1, m15), q4h_1), m32);
9150-
const __m128i q4_2 = _mm_sub_epi8(_mm_or_si128(_mm_and_si128(q4bits2_0, m15), q4h_2), m32);
9151-
const __m128i q4_3 = _mm_sub_epi8(_mm_or_si128(_mm_and_si128(q4bits2_1, m15), q4h_3), m32);
9152-
const __m128i q4_4 = _mm_sub_epi8(_mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m15), q4h_4), m32);
9153-
const __m128i q4_5 = _mm_sub_epi8(_mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m15), q4h_5), m32);
9154-
const __m128i q4_6 = _mm_sub_epi8(_mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m15), q4h_6), m32);
9155-
const __m128i q4_7 = _mm_sub_epi8(_mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m15), q4h_7), m32);
9154+
const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m15), q4h_0);
9155+
const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m15), q4h_1);
9156+
const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m15), q4h_2);
9157+
const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m15), q4h_3);
9158+
const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m15), q4h_4);
9159+
const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m15), q4h_5);
9160+
const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m15), q4h_6);
9161+
const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m15), q4h_7);
91569162

91579163
const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
91589164
const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
@@ -9163,14 +9169,14 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
91639169
const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
91649170
const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
91659171

9166-
__m128i p16_0 = mul_add_epi8_sse(q4_0, q8_0);
9167-
__m128i p16_1 = mul_add_epi8_sse(q4_1, q8_1);
9168-
__m128i p16_2 = mul_add_epi8_sse(q4_2, q8_2);
9169-
__m128i p16_3 = mul_add_epi8_sse(q4_3, q8_3);
9170-
__m128i p16_4 = mul_add_epi8_sse(q4_4, q8_4);
9171-
__m128i p16_5 = mul_add_epi8_sse(q4_5, q8_5);
9172-
__m128i p16_6 = mul_add_epi8_sse(q4_6, q8_6);
9173-
__m128i p16_7 = mul_add_epi8_sse(q4_7, q8_7);
9172+
__m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0);
9173+
__m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1);
9174+
__m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2);
9175+
__m128i p16_3 = _mm_maddubs_epi16(q4_3, q8_3);
9176+
__m128i p16_4 = _mm_maddubs_epi16(q4_4, q8_4);
9177+
__m128i p16_5 = _mm_maddubs_epi16(q4_5, q8_5);
9178+
__m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6);
9179+
__m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7);
91749180

91759181
const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0));
91769182
const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));
@@ -9191,10 +9197,11 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
91919197
sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
91929198
sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_4, p16_6));
91939199
sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_5, p16_7));
9194-
91959200
}
91969201

9197-
__m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
9202+
sumi_0 = _mm_sub_epi32(sumi_0, q8scld_0);
9203+
sumi_1 = _mm_sub_epi32(sumi_1, q8scld_1);
9204+
const __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
91989205
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi)), acc);
91999206
}
92009207

0 commit comments

Comments
 (0)