Skip to content

Commit e3a3432

Browse files
committed
better subtract method
1 parent 499e9f2 commit e3a3432

File tree

1 file changed

+49
-66
lines changed

1 file changed

+49
-66
lines changed

ggml/src/ggml-quants.c

Lines changed: 49 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -9107,9 +9107,9 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
91079107
const __m128i m4 = _mm_set1_epi8(0xF);
91089108
const __m128i m3 = _mm_set1_epi8(3);
91099109
const __m128i m32s = _mm_set1_epi8(32);
9110+
const __m128i m2 = _mm_set1_epi8(2);
91109111

9111-
__m256 acc1 = _mm256_setzero_ps();
9112-
__m256 acc2 = _mm256_setzero_ps();
9112+
__m256 acc = _mm256_setzero_ps();
91139113

91149114
for (int i = 0; i < nb; ++i) {
91159115

@@ -9123,8 +9123,6 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
91239123

91249124
__m128i sumi_0 = _mm_setzero_si128();
91259125
__m128i sumi_1 = _mm_setzero_si128();
9126-
__m128i sumi_2 = _mm_setzero_si128();
9127-
__m128i sumi_3 = _mm_setzero_si128();
91289126

91299127
__m128i shuffle = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
91309128
for (int j = 0; j < QK_K/128; ++j) {
@@ -9134,90 +9132,75 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
91349132

91359133
const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4);
91369134
const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4);
9135+
const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 2), m3), 4);
9136+
const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 2), m3), 4);
91379137
const __m128i q4h_4 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 4), m3), 4);
91389138
const __m128i q4h_5 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 4), m3), 4);
9139+
const __m128i q4h_6 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 6), m3), 4);
9140+
const __m128i q4h_7 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 6), m3), 4);
91399141

91409142
const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
91419143
const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
9144+
const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
9145+
const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
9146+
9147+
const __m128i q4_0 = _mm_sub_epi8(_mm_or_si128(_mm_and_si128(q4bits1_0, m4), q4h_0), m32s);
9148+
const __m128i q4_1 = _mm_sub_epi8(_mm_or_si128(_mm_and_si128(q4bits1_1, m4), q4h_1), m32s);
9149+
const __m128i q4_2 = _mm_sub_epi8(_mm_or_si128(_mm_and_si128(q4bits2_0, m4), q4h_2), m32s);
9150+
const __m128i q4_3 = _mm_sub_epi8(_mm_or_si128(_mm_and_si128(q4bits2_1, m4), q4h_3), m32s);
9151+
const __m128i q4_4 = _mm_sub_epi8(_mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m4), q4h_4), m32s);
9152+
const __m128i q4_5 = _mm_sub_epi8(_mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m4), q4h_5), m32s);
9153+
const __m128i q4_6 = _mm_sub_epi8(_mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m4), q4h_6), m32s);
9154+
const __m128i q4_7 = _mm_sub_epi8(_mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m4), q4h_7), m32s);
91429155

9143-
const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m4), q4h_0);
9144-
const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m4), q4h_1);
9145-
const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m4), q4h_4);
9146-
const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m4), q4h_5);
91479156
const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
9148-
const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 48;
9157+
const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
9158+
const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
9159+
const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
91499160
const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
9150-
const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 -= 48;
9151-
__m128i q8s_0 = _mm_maddubs_epi16(m32s, q8_0);
9152-
__m128i q8s_1 = _mm_maddubs_epi16(m32s, q8_1);
9153-
__m128i q8s_4 = _mm_maddubs_epi16(m32s, q8_4);
9154-
__m128i q8s_5 = _mm_maddubs_epi16(m32s, q8_5);
9155-
__m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0);
9156-
__m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1);
9157-
__m128i p16_4 = _mm_maddubs_epi16(q4_4, q8_4);
9158-
__m128i p16_5 = _mm_maddubs_epi16(q4_5, q8_5);
9159-
p16_0 = _mm_sub_epi16(p16_0, q8s_0);
9160-
p16_1 = _mm_sub_epi16(p16_1, q8s_1);
9161-
p16_4 = _mm_sub_epi16(p16_4, q8s_4);
9162-
p16_5 = _mm_sub_epi16(p16_5, q8s_5);
9161+
const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
9162+
const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
9163+
const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
9164+
9165+
__m128i p16_0 = mul_add_epi8_sse(q4_0, q8_0);
9166+
__m128i p16_1 = mul_add_epi8_sse(q4_1, q8_1);
9167+
__m128i p16_2 = mul_add_epi8_sse(q4_2, q8_2);
9168+
__m128i p16_3 = mul_add_epi8_sse(q4_3, q8_3);
9169+
__m128i p16_4 = mul_add_epi8_sse(q4_4, q8_4);
9170+
__m128i p16_5 = mul_add_epi8_sse(q4_5, q8_5);
9171+
__m128i p16_6 = mul_add_epi8_sse(q4_6, q8_6);
9172+
__m128i p16_7 = mul_add_epi8_sse(q4_7, q8_7);
91639173

91649174
const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle);
9165-
const __m128i scale_2 = _mm_shuffle_epi8(scales, _mm_add_epi8(shuffle, _mm_set1_epi8(4)));
9175+
shuffle = _mm_add_epi8(shuffle, m2);
9176+
const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle);
9177+
shuffle = _mm_add_epi8(shuffle, m2);
9178+
const __m128i scale_2 = _mm_shuffle_epi8(scales, shuffle);
9179+
shuffle = _mm_add_epi8(shuffle, m2);
9180+
const __m128i scale_3 = _mm_shuffle_epi8(scales, shuffle);
9181+
shuffle = _mm_add_epi8(shuffle, m2);
9182+
91669183
p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
91679184
p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1);
9168-
p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4);
9169-
p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_2, scale_2)), p16_5);
9170-
9171-
sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_1));
9172-
sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_4, p16_5));
9173-
9174-
const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
9175-
const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
9176-
const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 2), m3), 4);
9177-
const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 2), m3), 4);
9178-
const __m128i q4h_6 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 6), m3), 4);
9179-
const __m128i q4h_7 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 6), m3), 4);
9180-
const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m4), q4h_2);
9181-
const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m4), q4h_3);
9182-
const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m4), q4h_6);
9183-
const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m4), q4h_7);
9184-
9185-
const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
9186-
const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 48;
9187-
const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
9188-
const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16 ;
9189-
__m128i q8s_2 = _mm_maddubs_epi16(m32s, q8_2);
9190-
__m128i q8s_3 = _mm_maddubs_epi16(m32s, q8_3);
9191-
__m128i q8s_6 = _mm_maddubs_epi16(m32s, q8_6);
9192-
__m128i q8s_7 = _mm_maddubs_epi16(m32s, q8_7);
9193-
__m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2);
9194-
__m128i p16_3 = _mm_maddubs_epi16(q4_3, q8_3);
9195-
__m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6);
9196-
__m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7);
9197-
p16_2 = _mm_sub_epi16(p16_2, q8s_2);
9198-
p16_3 = _mm_sub_epi16(p16_3, q8s_3);
9199-
p16_6 = _mm_sub_epi16(p16_6, q8s_6);
9200-
p16_7 = _mm_sub_epi16(p16_7, q8s_7);
9201-
9202-
const __m128i scale_1 = _mm_shuffle_epi8(scales, _mm_add_epi8(shuffle, _mm_set1_epi8(2)));
9203-
const __m128i scale_3 = _mm_shuffle_epi8(scales, _mm_add_epi8(shuffle, _mm_set1_epi8(6)));
92049185
p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
92059186
p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3);
9187+
p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4);
9188+
p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_2, scale_2)), p16_5);
92069189
p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6);
92079190
p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_3, scale_3)), p16_7);
92089191

9209-
sumi_2 = _mm_add_epi32(sumi_2, _mm_add_epi32(p16_2, p16_3));
9210-
sumi_3 = _mm_add_epi32(sumi_3, _mm_add_epi32(p16_6, p16_7));
9192+
sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
9193+
sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
9194+
sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_4, p16_6));
9195+
sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_5, p16_7));
92119196

92129197
}
92139198

9214-
__m256i sumi1 = MM256_SET_M128I(sumi_0, sumi_1);
9215-
__m256i sumi2 = MM256_SET_M128I(sumi_2, sumi_3);
9216-
acc1 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi1)), acc1);
9217-
acc2 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi2)), acc2);
9199+
__m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
9200+
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
92189201
}
92199202

9220-
*s = hsum_float_8(_mm256_add_ps(acc1, acc2));
9203+
*s = hsum_float_8(acc);
92219204

92229205
#elif defined __riscv_v_intrinsic
92239206

0 commit comments

Comments
 (0)