@@ -8036,8 +8036,6 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
 #elif defined __loongarch_asx
 
-    const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
-    const __m256i m2 = __lasx_xvreplgr2vr_b(3);
     const __m256i m32s = __lasx_xvreplgr2vr_b(32);
 
     __m256 acc = (__m256)__lasx_xvldi(0);
@@ -8050,58 +8048,42 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         const uint8_t * restrict qh = x[i].qh;
         const int8_t  * restrict q8 = y[i].qs;
 
-        const __m128i scales = __lsx_vld((const __m128i*)x[i].scales, 0);
+        const __m128i scales128 = __lsx_vld((const __m128i*)x[i].scales, 0);
+        const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
+        const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask));
 
         __m256i sumi = __lasx_xvldi(0);
 
-        int is = 0;
-
         for (int j = 0; j < QK_K/128; ++j) {
 
-            const __m128i scale_0 = lsx_shuffle_b(scales, get_scale_shuffle(is + 0));
-            const __m128i scale_1 = lsx_shuffle_b(scales, get_scale_shuffle(is + 1));
-            const __m128i scale_2 = lsx_shuffle_b(scales, get_scale_shuffle(is + 2));
-            const __m128i scale_3 = lsx_shuffle_b(scales, get_scale_shuffle(is + 3));
-            is += 4;
-
             const __m256i q4bits1 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
             const __m256i q4bits2 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
             const __m256i q4bitsH = __lasx_xvld((const __m256i*)qh, 0); qh += 32;
 
-            const __m256i q4h_0 = __lasx_xvslli_h(__lasx_xvand_v(q4bitsH, m2), 4);
-            const __m256i q4h_1 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 2), m2), 4);
-            const __m256i q4h_2 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 4), m2), 4);
-            const __m256i q4h_3 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 6), m2), 4);
+            const __m256i q4h_0 = __lasx_xvslli_b(__lasx_xvandi_b(q4bitsH, 3), 4);
+            const __m256i q4h_1 = __lasx_xvslli_b(__lasx_xvandi_b(q4bitsH, 3 << 2), 2);
+            const __m256i q4h_2 = __lasx_xvandi_b(q4bitsH, 3 << 4);
+            const __m256i q4h_3 = __lasx_xvsrli_b(__lasx_xvandi_b(q4bitsH, 3 << 6), 2);
 
-            const __m256i q4_0 = __lasx_xvor_v(__lasx_xvand_v(q4bits1, m4), q4h_0);
-            const __m256i q4_1 = __lasx_xvor_v(__lasx_xvand_v(q4bits2, m4), q4h_1);
-            const __m256i q4_2 = __lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits1, 4), m4), q4h_2);
-            const __m256i q4_3 = __lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits2, 4), m4), q4h_3);
+            const __m256i q4_0 = __lasx_xvor_v(__lasx_xvandi_b(q4bits1, 0xf), q4h_0);
+            const __m256i q4_1 = __lasx_xvor_v(__lasx_xvandi_b(q4bits2, 0xf), q4h_1);
+            const __m256i q4_2 = __lasx_xvor_v(__lasx_xvsrli_b(q4bits1, 4), q4h_2);
+            const __m256i q4_3 = __lasx_xvor_v(__lasx_xvsrli_b(q4bits2, 4), q4h_3);
 
             const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
             const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
             const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
             const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
 
-            __m256i q8s_0 = lasx_maddubs_h(m32s, q8_0);
-            __m256i q8s_1 = lasx_maddubs_h(m32s, q8_1);
-            __m256i q8s_2 = lasx_maddubs_h(m32s, q8_2);
-            __m256i q8s_3 = lasx_maddubs_h(m32s, q8_3);
-
-            __m256i p16_0 = lasx_maddubs_h(q4_0, q8_0);
-            __m256i p16_1 = lasx_maddubs_h(q4_1, q8_1);
-            __m256i p16_2 = lasx_maddubs_h(q4_2, q8_2);
-            __m256i p16_3 = lasx_maddubs_h(q4_3, q8_3);
-
-            p16_0 = __lasx_xvsub_h(p16_0, q8s_0);
-            p16_1 = __lasx_xvsub_h(p16_1, q8s_1);
-            p16_2 = __lasx_xvsub_h(p16_2, q8s_2);
-            p16_3 = __lasx_xvsub_h(p16_3, q8s_3);
-
-            p16_0 = lasx_madd_h(lasx_ext8_16(scale_0), p16_0);
-            p16_1 = lasx_madd_h(lasx_ext8_16(scale_1), p16_1);
-            p16_2 = lasx_madd_h(lasx_ext8_16(scale_2), p16_2);
-            p16_3 = lasx_madd_h(lasx_ext8_16(scale_3), p16_3);
+            __m256i p16_0 = lasx_madd_h_b(__lasx_xvsub_b(q4_0, m32s), q8_0);
+            __m256i p16_1 = lasx_madd_h_b(__lasx_xvsub_b(q4_1, m32s), q8_1);
+            __m256i p16_2 = lasx_madd_h_b(__lasx_xvsub_b(q4_2, m32s), q8_2);
+            __m256i p16_3 = lasx_madd_h_b(__lasx_xvsub_b(q4_3, m32s), q8_3);
+
+            p16_0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p16_0);
+            p16_1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p16_1);
+            p16_2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p16_2);
+            p16_3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p16_3);
 
             sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1));
             sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_2, p16_3));
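Notes on the change follow.

The byte-wise extraction above is the vector form of q6_K decoding: each quant takes its low 4 bits from ql and its high 2 bits from qh, and carries an implicit bias of 32. A minimal scalar sketch of the same decoding, following ggml's block_q6_K layout (the helper name is illustrative, not from the patch):

    #include <stdint.h>

    // Decode one 128-quant half of a q6_K block into signed values in [-32, 31].
    // ql holds 64 bytes of packed low nibbles, qh holds 32 bytes of packed
    // 2-bit high parts.
    static void q6k_decode_128(const uint8_t * ql, const uint8_t * qh, int8_t * out) {
        for (int l = 0; l < 32; ++l) {
            out[l +  0] = (int8_t)(((ql[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32);
            out[l + 32] = (int8_t)(((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32);
            out[l + 64] = (int8_t)(((ql[l +  0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32);
            out[l + 96] = (int8_t)(((ql[l + 32] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32);
        }
    }

The four out groups correspond to q4_0..q4_3 in the loop body. Each q4h_k isolates one 2-bit field with xvandi_b and moves it to bits 4..5, which is why the shift amounts differ: the field at bits 0..1 needs << 4, the field at bits 2..3 only << 2, the field at bits 4..5 none, and the field at bits 6..7 a >> 2. Working on bytes (xvslli_b/xvsrli_b on masked values) rather than 16-bit lanes also removes the need for the m4 and m2 mask registers deleted above.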
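The up-front even/odd shuffle of the 16 sub-block scales is what lets the inner loop drop the four lsx_shuffle_b calls per iteration: after vshuf_b and the 8-to-16-bit sign extension, the even-indexed scales sit in the low 128-bit lane and the odd-indexed ones in the high lane, so a single per-lane broadcast with index 4*j + k supplies scale 2*(4*j + k) to the low lane of p16_k and scale 2*(4*j + k) + 1 to the high lane, matching the two 16-quant sub-blocks each p16 vector covers. A scalar sketch of that index mapping (assumed lane layout, standalone code):

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        int8_t scales[16];                       // one scale per 16-quant sub-block
        for (int i = 0; i < 16; ++i) scales[i] = (int8_t)i;

        // vshuf_b with {0,2,4,...,14, 1,3,...,15}: even scales first, then odd.
        int16_t shuffled[16];
        for (int i = 0; i < 8; ++i) {
            shuffled[i]     = scales[2 * i];     // -> low 128-bit lane after ext8_16
            shuffled[i + 8] = scales[2 * i + 1]; // -> high 128-bit lane
        }

        // xvrepl128vei_h(scales_shuffled, 4*j + k) broadcasts, within each lane:
        for (int j = 0; j < 2; ++j)
            for (int k = 0; k < 4; ++k)
                printf("j=%d k=%d: low lane scale %d, high lane scale %d\n",
                       j, k, shuffled[4 * j + k], shuffled[8 + 4 * j + k]);
        return 0;
    }

For j=0, k=0 this yields scales 0 and 1, exactly the pair needed for quants 0..15 (low lane of q4_0) and 16..31 (high lane), and so on through scales 14 and 15 at j=1, k=3.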
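Folding the bias into the quants is the algebraic core of the maddubs removal: the old sequence computed q*s and 32*s with two unsigned-by-signed multiply-adds and subtracted, while the new one computes (q - 32)*s directly with one signed multiply-add. This is safe because q - 32 fits int8 ([-32, 31]) and each adjacent pair sum is bounded by 2*32*128 = 8192, well inside int16. A brute-force scalar check of the identity and the bound (standalone sketch, not patch code):

    #include <assert.h>
    #include <stdint.h>

    int main(void) {
        // For every representable q6_K quant q and q8_K value s:
        // q*s - 32*s == (q - 32)*s, and an adjacent pair sum fits int16_t.
        for (int q = 0; q < 64; ++q) {
            for (int s = -128; s <= 127; ++s) {
                int old_way = q * s - 32 * s;   // lasx_maddubs_h(q4, q8) - lasx_maddubs_h(m32s, q8)
                int new_way = (q - 32) * s;     // lasx_madd_h_b(q4 - 32, q8)
                assert(old_way == new_way);
                assert(2 * new_way >= INT16_MIN && 2 * new_way <= INT16_MAX);
            }
        }
        return 0;
    }

Per 128-quant iteration this trades eight multiply-adds plus four subtractions for four byte subtractions and four multiply-adds, on top of the scale-shuffle work hoisted out of the loop.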