Skip to content

Commit 59f6ae1

Browse files
committed
Optimize ggml_vec_dot_q3_K_q8_K for LoongArch ASX
1 parent 27e8a23 commit 59f6ae1

File tree

1 file changed

+59
-52
lines changed

1 file changed

+59
-52
lines changed

ggml/src/ggml-cpu/ggml-cpu-quants.c

Lines changed: 59 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,41 @@ static __m256i lasx_packs_h(__m256i a, __m256i b) {
562562
return __lasx_xvpickev_b(tmp1, tmp);
563563
}
564564

565+
// Pairwise multiply-add of two byte vectors into 16-bit lanes.
// Computes the "even" widening product and the "odd" widening product of
// a and b, then adds the two with a 16-bit lane add.
// NOTE(review): per the LASX intrinsic naming, xvmulwev_h_b / xvmulwod_h_b
// should be the signed byte -> halfword multiplies of the even / odd
// elements, so lane i would hold a[2i]*b[2i] + a[2i+1]*b[2i+1] — confirm
// against the LASX reference.
static inline __m256i lasx_madd_h_b(__m256i a, __m256i b) {
566+
__m256i tmp1, tmp2;
567+
// products of the even-positioned elements
tmp1 = __lasx_xvmulwev_h_b(a, b);
568+
// products of the odd-positioned elements
tmp2 = __lasx_xvmulwod_h_b(a, b);
569+
// sum the two partial products per 16-bit lane
return __lasx_xvadd_h(tmp1, tmp2);
570+
}
571+
572+
// Wrapper around __lasx_xvrepl128vei_h that accepts the element index b as
// a function argument instead of a literal.
// The switch maps each supported index (0..7) to a call whose second
// argument is a literal constant — presumably because the intrinsic needs a
// compile-time immediate; confirm against the LASX reference.
// b must be in 0..7: any other value reaches __builtin_unreachable(), which
// is undefined behavior.
static inline __m256i lasx_xvrepl128vei_h(__m256i a, const unsigned int b) {
573+
switch (b) {
574+
case 0: return __lasx_xvrepl128vei_h(a, 0);
575+
case 1: return __lasx_xvrepl128vei_h(a, 1);
576+
case 2: return __lasx_xvrepl128vei_h(a, 2);
577+
case 3: return __lasx_xvrepl128vei_h(a, 3);
578+
case 4: return __lasx_xvrepl128vei_h(a, 4);
579+
case 5: return __lasx_xvrepl128vei_h(a, 5);
580+
case 6: return __lasx_xvrepl128vei_h(a, 6);
581+
case 7: return __lasx_xvrepl128vei_h(a, 7);
582+
// callers guarantee b <= 7; tells the compiler no other value occurs
default: __builtin_unreachable();
583+
}
584+
}
585+
586+
// Wrapper around __lasx_xvandi_b that ANDs a with the single-bit mask
// (1 << b), where the bit position b is a function argument.
// Each case passes a constant mask expression to the intrinsic — presumably
// because the intrinsic needs a compile-time immediate; confirm against the
// LASX reference.
// b must be in 0..7: any other value reaches __builtin_unreachable(), which
// is undefined behavior.
static inline __m256i lasx_xvandi_b_bit(__m256i a, const unsigned int b) {
587+
switch (b) {
588+
case 0: return __lasx_xvandi_b(a, 1 << 0);
589+
case 1: return __lasx_xvandi_b(a, 1 << 1);
590+
case 2: return __lasx_xvandi_b(a, 1 << 2);
591+
case 3: return __lasx_xvandi_b(a, 1 << 3);
592+
case 4: return __lasx_xvandi_b(a, 1 << 4);
593+
case 5: return __lasx_xvandi_b(a, 1 << 5);
594+
case 6: return __lasx_xvandi_b(a, 1 << 6);
595+
case 7: return __lasx_xvandi_b(a, 1 << 7);
596+
// callers guarantee b <= 7; tells the compiler no other value occurs
default: __builtin_unreachable();
597+
}
598+
}
599+
565600
// multiply int8_t, add results pairwise twice
566601
static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
567602
// Get absolute values of x vectors
@@ -5771,8 +5806,6 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
57715806

57725807
#elif defined __loongarch_asx
57735808

5774-
const __m256i m3 = __lasx_xvreplgr2vr_b(3);
5775-
const __m256i mone = __lasx_xvreplgr2vr_b(1);
57765809
const __m128i m32 = __lsx_vreplgr2vr_b(32);
57775810

57785811
__m256 acc = (__m256)__lasx_xvldi(0);
@@ -5792,84 +5825,58 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
57925825
(aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
57935826
(aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
57945827
scales128 = __lsx_vsub_b(scales128, m32);
5795-
const __m256i all_scales = lasx_ext8_16(scales128);
5796-
const __m128i l_scales = lasx_extracti128(all_scales, 0);
5797-
const __m128i h_scales = lasx_extracti128(all_scales, 1);
5798-
const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)};
5828+
5829+
const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
5830+
const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask));
57995831

58005832
// high bit
58015833
const __m256i hbits = __lasx_xvld((const __m256i*)x[i].hmask, 0);
58025834

58035835
// integer accumulator
58045836
__m256i sumi = __lasx_xvldi(0);
58055837

5806-
int bit = 0;
5807-
int is = 0;
5808-
__m256i xvbit;
5809-
5810-
58115838
for (int j = 0; j < QK_K/128; ++j) {
58125839
// load low 2 bits
58135840
const __m256i q3bits = __lasx_xvld((const __m256i*)q3, 0); q3 += 32;
58145841

5815-
xvbit = __lasx_xvreplgr2vr_h(bit);
58165842
// prepare low and high bits
5817-
const __m256i q3l_0 = __lasx_xvand_v(q3bits, m3);
5818-
const __m256i q3h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
5819-
++bit;
5820-
5821-
xvbit = __lasx_xvreplgr2vr_h(bit);
5822-
const __m256i q3l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 2), m3);
5823-
const __m256i q3h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
5824-
++bit;
5825-
5826-
xvbit = __lasx_xvreplgr2vr_h(bit);
5827-
const __m256i q3l_2 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 4), m3);
5828-
const __m256i q3h_2 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
5829-
++bit;
5830-
5831-
xvbit = __lasx_xvreplgr2vr_h(bit);
5832-
const __m256i q3l_3 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 6), m3);
5833-
const __m256i q3h_3 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
5834-
++bit;
5843+
const __m256i q3l_0 = __lasx_xvandi_b(q3bits, 3);
5844+
const __m256i q3l_1 = __lasx_xvandi_b(__lasx_xvsrli_b(q3bits, 2), 3);
5845+
const __m256i q3l_2 = __lasx_xvandi_b(__lasx_xvsrli_b(q3bits, 4), 3);
5846+
const __m256i q3l_3 = __lasx_xvsrli_b(q3bits, 6);
5847+
const __m256i q3h_0 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 0), 0), 2);
5848+
const __m256i q3h_1 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 1), 0), 2);
5849+
const __m256i q3h_2 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 2), 0), 2);
5850+
const __m256i q3h_3 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 3), 0), 2);
5851+
const __m256i q3_0 = __lasx_xvor_v(q3h_0, q3l_0);
5852+
const __m256i q3_1 = __lasx_xvor_v(q3h_1, q3l_1);
5853+
const __m256i q3_2 = __lasx_xvor_v(q3h_2, q3l_2);
5854+
const __m256i q3_3 = __lasx_xvor_v(q3h_3, q3l_3);
58355855

58365856
// load Q8 quants
58375857
const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
58385858
const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
58395859
const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
58405860
const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
58415861

5842-
// Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use lasx_maddubs_h,
5843-
// and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
5844-
// and 2 if the high bit was set)
5845-
__m256i q8s_0 = lasx_maddubs_h(q3h_0, q8_0);
5846-
__m256i q8s_1 = lasx_maddubs_h(q3h_1, q8_1);
5847-
__m256i q8s_2 = lasx_maddubs_h(q3h_2, q8_2);
5848-
__m256i q8s_3 = lasx_maddubs_h(q3h_3, q8_3);
5849-
5850-
__m256i p16_0 = lasx_maddubs_h(q3l_0, q8_0);
5851-
__m256i p16_1 = lasx_maddubs_h(q3l_1, q8_1);
5852-
__m256i p16_2 = lasx_maddubs_h(q3l_2, q8_2);
5853-
__m256i p16_3 = lasx_maddubs_h(q3l_3, q8_3);
5854-
5855-
p16_0 = __lasx_xvsub_h(p16_0, q8s_0);
5856-
p16_1 = __lasx_xvsub_h(p16_1, q8s_1);
5857-
p16_2 = __lasx_xvsub_h(p16_2, q8s_2);
5858-
p16_3 = __lasx_xvsub_h(p16_3, q8s_3);
5862+
__m256i p16_0 = lasx_madd_h_b(q8_0, q3_0);
5863+
__m256i p16_1 = lasx_madd_h_b(q8_1, q3_1);
5864+
__m256i p16_2 = lasx_madd_h_b(q8_2, q3_2);
5865+
__m256i p16_3 = lasx_madd_h_b(q8_3, q3_3);
58595866

58605867
// multiply with scales
5861-
p16_0 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0);
5862-
p16_1 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1);
5863-
p16_2 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2);
5864-
p16_3 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3);
5868+
p16_0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p16_0);
5869+
p16_1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p16_1);
5870+
p16_2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p16_2);
5871+
p16_3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p16_3);
58655872

58665873
// accumulate
58675874
p16_0 = __lasx_xvadd_w(p16_0, p16_1);
58685875
p16_2 = __lasx_xvadd_w(p16_2, p16_3);
58695876
sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_2));
58705877
}
58715878
// multiply with block scale and accumulate
5872-
acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);//FIXME
5879+
acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);
58735880
}
58745881

58755882
*s = hsum_float_8(acc);

0 commit comments

Comments
 (0)