Skip to content

Commit f832891

Browse files
committed
Optimize ggml_vec_dot_q6_K_q8_K for LoongArch ASX
1 parent a34fc87 commit f832891

File tree

1 file changed

+20
-38
lines changed

1 file changed

+20
-38
lines changed

ggml/src/ggml-cpu/ggml-cpu-quants.c

Lines changed: 20 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -8036,8 +8036,6 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
80368036

80378037
#elif defined __loongarch_asx
80388038

8039-
const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
8040-
const __m256i m2 = __lasx_xvreplgr2vr_b(3);
80418039
const __m256i m32s = __lasx_xvreplgr2vr_b(32);
80428040

80438041
__m256 acc = (__m256)__lasx_xvldi(0);
@@ -8050,58 +8048,42 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
80508048
const uint8_t * restrict qh = x[i].qh;
80518049
const int8_t * restrict q8 = y[i].qs;
80528050

8053-
const __m128i scales = __lsx_vld((const __m128i*)x[i].scales, 0);
8051+
const __m128i scales128 = __lsx_vld((const __m128i*)x[i].scales, 0);
8052+
const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
8053+
const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask));
80548054

80558055
__m256i sumi = __lasx_xvldi(0);
80568056

8057-
int is = 0;
8058-
80598057
for (int j = 0; j < QK_K/128; ++j) {
80608058

8061-
const __m128i scale_0 = lsx_shuffle_b(scales, get_scale_shuffle(is + 0));
8062-
const __m128i scale_1 = lsx_shuffle_b(scales, get_scale_shuffle(is + 1));
8063-
const __m128i scale_2 = lsx_shuffle_b(scales, get_scale_shuffle(is + 2));
8064-
const __m128i scale_3 = lsx_shuffle_b(scales, get_scale_shuffle(is + 3));
8065-
is += 4;
8066-
80678059
const __m256i q4bits1 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
80688060
const __m256i q4bits2 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
80698061
const __m256i q4bitsH = __lasx_xvld((const __m256i*)qh, 0); qh += 32;
80708062

8071-
const __m256i q4h_0 = __lasx_xvslli_h(__lasx_xvand_v(q4bitsH, m2), 4);
8072-
const __m256i q4h_1 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 2), m2), 4);
8073-
const __m256i q4h_2 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 4), m2), 4);
8074-
const __m256i q4h_3 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 6), m2), 4);
8063+
const __m256i q4h_0 = __lasx_xvslli_b(__lasx_xvandi_b(q4bitsH, 3), 4);
8064+
const __m256i q4h_1 = __lasx_xvslli_b(__lasx_xvandi_b(q4bitsH, 3 << 2), 2);
8065+
const __m256i q4h_2 = __lasx_xvandi_b(q4bitsH, 3 << 4);
8066+
const __m256i q4h_3 = __lasx_xvsrli_b(__lasx_xvandi_b(q4bitsH, 3 << 6), 2);
80758067

8076-
const __m256i q4_0 = __lasx_xvor_v(__lasx_xvand_v(q4bits1, m4), q4h_0);
8077-
const __m256i q4_1 = __lasx_xvor_v(__lasx_xvand_v(q4bits2, m4), q4h_1);
8078-
const __m256i q4_2 = __lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits1, 4), m4), q4h_2);
8079-
const __m256i q4_3 = __lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits2, 4), m4), q4h_3);
8068+
const __m256i q4_0 = __lasx_xvor_v(__lasx_xvandi_b(q4bits1, 0xf), q4h_0);
8069+
const __m256i q4_1 = __lasx_xvor_v(__lasx_xvandi_b(q4bits2, 0xf), q4h_1);
8070+
const __m256i q4_2 = __lasx_xvor_v(__lasx_xvsrli_b(q4bits1, 4), q4h_2);
8071+
const __m256i q4_3 = __lasx_xvor_v(__lasx_xvsrli_b(q4bits2, 4), q4h_3);
80808072

80818073
const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
80828074
const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
80838075
const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
80848076
const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
80858077

8086-
__m256i q8s_0 = lasx_maddubs_h(m32s, q8_0);
8087-
__m256i q8s_1 = lasx_maddubs_h(m32s, q8_1);
8088-
__m256i q8s_2 = lasx_maddubs_h(m32s, q8_2);
8089-
__m256i q8s_3 = lasx_maddubs_h(m32s, q8_3);
8090-
8091-
__m256i p16_0 = lasx_maddubs_h(q4_0, q8_0);
8092-
__m256i p16_1 = lasx_maddubs_h(q4_1, q8_1);
8093-
__m256i p16_2 = lasx_maddubs_h(q4_2, q8_2);
8094-
__m256i p16_3 = lasx_maddubs_h(q4_3, q8_3);
8095-
8096-
p16_0 = __lasx_xvsub_h(p16_0, q8s_0);
8097-
p16_1 = __lasx_xvsub_h(p16_1, q8s_1);
8098-
p16_2 = __lasx_xvsub_h(p16_2, q8s_2);
8099-
p16_3 = __lasx_xvsub_h(p16_3, q8s_3);
8100-
8101-
p16_0 = lasx_madd_h(lasx_ext8_16(scale_0), p16_0);
8102-
p16_1 = lasx_madd_h(lasx_ext8_16(scale_1), p16_1);
8103-
p16_2 = lasx_madd_h(lasx_ext8_16(scale_2), p16_2);
8104-
p16_3 = lasx_madd_h(lasx_ext8_16(scale_3), p16_3);
8078+
__m256i p16_0 = lasx_madd_h_b(__lasx_xvsub_b(q4_0, m32s), q8_0);
8079+
__m256i p16_1 = lasx_madd_h_b(__lasx_xvsub_b(q4_1, m32s), q8_1);
8080+
__m256i p16_2 = lasx_madd_h_b(__lasx_xvsub_b(q4_2, m32s), q8_2);
8081+
__m256i p16_3 = lasx_madd_h_b(__lasx_xvsub_b(q4_3, m32s), q8_3);
8082+
8083+
p16_0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p16_0);
8084+
p16_1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p16_1);
8085+
p16_2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p16_2);
8086+
p16_3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p16_3);
81058087

81068088
sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1));
81078089
sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_2, p16_3));

0 commit comments

Comments (0)