@@ -5000,9 +5000,6 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
 #elif defined __loongarch_asx
 
-    const __m256i m3 = __lasx_xvreplgr2vr_b(3);
-    const __m128i m4 = __lsx_vreplgr2vr_b(0xF);
-
     __m256 acc = (__m256)__lasx_xvldi(0);
 
     for (int i = 0; i < nb; ++i) {
@@ -5013,18 +5010,15 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         const uint8_t * restrict q2 = x[i].qs;
         const int8_t  * restrict q8 = y[i].qs;
 
-        const __m128i mins_and_scales = __lsx_vld((const __m128i*)x[i].scales, 0);
-        const __m128i scales8 = __lsx_vand_v(mins_and_scales, m4);
-        const __m128i mins8 = __lsx_vand_v(__lsx_vsrli_h(mins_and_scales, 4), m4);
-        const __m256i mins = lasx_ext8_16(mins8);
+        const __m128i mins_and_scales128 = __lsx_vld((const __m128i*)x[i].scales, 0);
+        const __m128i scales128 = __lsx_vandi_b(mins_and_scales128, 0xf);
+        const __m256i mins = lasx_ext8_16(__lsx_vsrli_b(mins_and_scales128, 4));
         const __m256i prod = lasx_madd_h(mins, __lasx_xvld((const __m256i*)y[i].bsums, 0));
 
         acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(dmin), __lasx_xvffint_s_w(prod), acc);
 
-        const __m256i all_scales = lasx_ext8_16(scales8);
-        const __m128i l_scales = lasx_extracti128(all_scales, 0);
-        const __m128i h_scales = lasx_extracti128(all_scales, 1);
-        const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)};
+        const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
+        const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask));
 
         __m256i sumi = __lasx_xvldi(0);
 
@@ -5037,20 +5031,20 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
             const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
             const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
 
-            const __m256i q2_0 = __lasx_xvand_v(q2bits, m3);
-            const __m256i q2_1 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 2), m3);
-            const __m256i q2_2 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 4), m3);
-            const __m256i q2_3 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 6), m3);
+            const __m256i q2_0 = __lasx_xvandi_b(q2bits, 3);
+            const __m256i q2_1 = __lasx_xvandi_b(__lasx_xvsrli_b(q2bits, 2), 3);
+            const __m256i q2_2 = __lasx_xvandi_b(__lasx_xvsrli_b(q2bits, 4), 3);
+            const __m256i q2_3 = __lasx_xvsrli_b(q2bits, 6);
 
-            __m256i p0 = lasx_maddubs_h(q2_0, q8_0);
-            __m256i p1 = lasx_maddubs_h(q2_1, q8_1);
-            __m256i p2 = lasx_maddubs_h(q2_2, q8_2);
-            __m256i p3 = lasx_maddubs_h(q2_3, q8_3);
+            __m256i p0 = lasx_madd_h_b(q2_0, q8_0);
+            __m256i p1 = lasx_madd_h_b(q2_1, q8_1);
+            __m256i p2 = lasx_madd_h_b(q2_2, q8_2);
+            __m256i p3 = lasx_madd_h_b(q2_3, q8_3);
 
-            p0 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(0)), p0);
-            p1 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(1)), p1);
-            p2 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(2)), p2);
-            p3 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(3)), p3);
+            p0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p0);
+            p1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p1);
+            p2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p2);
+            p3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p3);
 
             p0 = __lasx_xvadd_w(p0, p1);
             p2 = __lasx_xvadd_w(p2, p3);