@@ -7292,19 +7292,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
     *s = vec_extract(vsumf0, 0);
 
 #elif defined __loongarch_asx
-    GGML_UNUSED(kmask1);
-    GGML_UNUSED(kmask2);
-    GGML_UNUSED(kmask3);
-
-    const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
-    const __m128i mzero = __lsx_vldi(0);
-    const __m256i mone = __lasx_xvreplgr2vr_b(1);
 
     __m256 acc = (__m256)__lasx_xvldi(0);
+    __m128 acc_m = (__m128)__lsx_vldi(0);
 
-    float summs = 0.f;
-
-    for (int i = 0; i < nb; ++i) {
+    for (int i = 0; i < nb; ++i) {
 
         const uint8_t * restrict q5 = x[i].qs;
         const int8_t  * restrict q8 = y[i].qs;
@@ -7319,49 +7311,40 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         utmp[2] = uaux;
         utmp[0] &= kmask1;
 
-        const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]));
+        const __m128i mins_and_scales128 = lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]);
+        const __m128i mins128 = __lsx_vexth_h_b(mins_and_scales128);
+        const __m128i scales128 = __lsx_vsllwil_h_b(mins_and_scales128, 0);
 
         const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0);
         const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1));
-        const __m128i prod = lsx_madd_h(lasx_extracti128(mins_and_scales, 1), q8s);
-        const __m128i hsum = lsx_hadd_w(lsx_hadd_w(prod, mzero), mzero);
-        summs += dmin * __lsx_vpickve2gr_w(hsum, 0);    //TODO check
+        const __m128i prod = lsx_madd_h(mins128, q8s);
+        acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m);
 
-        const __m128i sc128 = lasx_extracti128(mins_and_scales, 0);
-        const __m256i scales = lasx_insertf128(sc128, sc128);
+        const __m256i scales = lasx_insertf128(scales128, scales128);
 
         const __m256i hbits = __lasx_xvld((const __m256i*)x[i].qh, 0);
-        __m256i hmask = mone;
 
         __m256i sumi = __lasx_xvldi(0);
 
-        int bit = 0;
-        __m256i xvbit;
-
         for (int j = 0; j < QK_K/64; ++j) {
 
-            const __m256i scale_0 = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+0));
-            const __m256i scale_1 = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+1));
+            const __m256i scale_0 = lasx_xvrepl128vei_h(scales, 2 * j + 0);
+            const __m256i scale_1 = lasx_xvrepl128vei_h(scales, 2 * j + 1);
 
             const __m256i q5bits = __lasx_xvld((const __m256i*)q5, 0); q5 += 32;
 
-            xvbit = __lasx_xvreplgr2vr_h(bit++);
-            const __m256i q5l_0 = __lasx_xvand_v(q5bits, m4);
-            const __m256i q5h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4);
-            const __m256i q5_0  = __lasx_xvadd_b(q5l_0, q5h_0);
-            hmask = __lasx_xvslli_h(hmask, 1);
-
-            xvbit = __lasx_xvreplgr2vr_h(bit++);
-            const __m256i q5l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q5bits, 4), m4);
-            const __m256i q5h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4);
-            const __m256i q5_1  = __lasx_xvadd_b(q5l_1, q5h_1);
-            hmask = __lasx_xvslli_h(hmask, 1);
+            const __m256i q5l_0 = __lasx_xvandi_b(q5bits, 0xf);
+            const __m256i q5l_1 = __lasx_xvsrli_b(q5bits, 4);
+            const __m256i q5h_0 = __lasx_xvnori_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 2 * j + 0), 0), 0xef);
+            const __m256i q5h_1 = __lasx_xvnori_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 2 * j + 1), 0), 0xef);
+            const __m256i q5_0  = __lasx_xvor_v(q5l_0, q5h_0);
+            const __m256i q5_1  = __lasx_xvor_v(q5l_1, q5h_1);
 
             const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
             const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
 
-            __m256i p16_0 = lasx_maddubs_h(q5_0, q8_0);
-            __m256i p16_1 = lasx_maddubs_h(q5_1, q8_1);
+            __m256i p16_0 = lasx_madd_h_b(q5_0, q8_0);
+            __m256i p16_1 = lasx_madd_h_b(q5_1, q8_1);
 
             p16_0 = lasx_madd_h(scale_0, p16_0);
             p16_1 = lasx_madd_h(scale_1, p16_1);
@@ -7375,7 +7358,10 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
     }
 
-    *s = hsum_float_8(acc) + summs;
+    acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vbsrl_v(acc_m, 8));
+    acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vbsrl_v(acc_m, 4));
+
+    *s = hsum_float_8(acc) + ((v4f32)acc_m)[0];
 
 #else
 
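
For reference, the two tricks this change relies on, as I read the patch (a sketch, not part of the commit):

1. High-bit extraction. The old code walked a `hmask`/`xvbit` pair across `hbits`, shifting the selected bit down and back up into position 4. The new code instead isolates bit `2*j+k` of every byte and maps "bit set" directly to the byte value 0x10 with a compare-and-nor pair, so the 5-bit quant is assembled with a single OR against the low nibble. A scalar model of what the three byte-wise ops compute per byte (helper names are illustrative, not from the patch):

    #include <stdint.h>

    /* per-byte model of xvandi_b_bit -> xvseqi_b -> xvnori_b */
    static inline uint8_t q5h_byte(uint8_t h, int k) {
        uint8_t t = h & (uint8_t)(1u << k);   /* xvandi_b_bit: isolate bit k          */
        uint8_t m = (t == 0) ? 0xFF : 0x00;   /* xvseqi_b(t, 0): byte-wise compare    */
        return (uint8_t)~(m | 0xEF);          /* xvnori_b(m, 0xEF): 0x10 iff bit set  */
    }

    /* xvor_v replaces the old xvadd_b, which is equivalent here because the
       low nibble (0..15) and the 0x10 high bit never overlap */
    static inline uint8_t q5_byte(uint8_t q5bits, uint8_t h, int k, int hi) {
        uint8_t q5l = hi ? (uint8_t)(q5bits >> 4) : (uint8_t)(q5bits & 0x0F);
        return (uint8_t)(q5l | q5h_byte(h, k));  /* 5-bit quant in 0..31 */
    }

2. Mins accumulation. The scalar `summs` and its per-block `hsum`/`vpickve2gr_w` reduction (the line tagged "TODO check") are replaced by the float vector `acc_m`, updated once per block with a fused multiply-add. The two `__lsx_vbsrl_v` shifts at the end fold the four lanes together, so `((v4f32)acc_m)[0]` holds the complete mins correction added to `hsum_float_8(acc)`.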