@@ -6569,11 +6569,6 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
     *s = vec_extract(vsumf0, 0);
 
 #elif defined __loongarch_asx
-    GGML_UNUSED(kmask1);
-    GGML_UNUSED(kmask2);
-    GGML_UNUSED(kmask3);
-
-    const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
 
     __m256 acc = (__m256)__lasx_xvldi(0);
     __m128 acc_m = (__m128)__lsx_vldi(0);
@@ -6593,33 +6588,34 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         const uint8_t * restrict q4 = x[i].qs;
         const int8_t  * restrict q8 = y[i].qs;
 
-        const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]));
+        const __m128i mins_and_scales128 = lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]);
+        const __m128i mins128 = __lsx_vexth_h_b(mins_and_scales128);
+        const __m128i scales128 = __lsx_vsllwil_h_b(mins_and_scales128, 0);
 
         const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0);
         const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1));
-        const __m128i prod = lsx_madd_h(lasx_extracti128(mins_and_scales, 1), q8s);
+        const __m128i prod = lsx_madd_h(mins128, q8s);
         acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m);
 
-        const __m128i sc128 = lasx_extracti128(mins_and_scales, 0);
-        const __m256i scales = lasx_insertf128(sc128, sc128);
+        const __m256i scales = lasx_insertf128(scales128, scales128);
 
         __m256i sumi = __lasx_xvldi(0);
 
         for (int j = 0; j < QK_K/64; ++j) {
 
-            const __m256i scale_l = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+0));
-            const __m256i scale_h = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+1));
+            const __m256i scale_l = lasx_xvrepl128vei_h(scales, 2 * j + 0);
+            const __m256i scale_h = lasx_xvrepl128vei_h(scales, 2 * j + 1);
 
             const __m256i q4bits = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
-            const __m256i q4l = __lasx_xvand_v(q4bits, m4);
-            const __m256i q4h = __lasx_xvand_v(__lasx_xvsrli_h(q4bits, 4), m4);
+            const __m256i q4l = __lasx_xvandi_b(q4bits, 0xf);
+            const __m256i q4h = __lasx_xvsrli_b(q4bits, 4);
 
             const __m256i q8l = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
-            __m256i p16l = lasx_maddubs_h(q4l, q8l);
+            __m256i p16l = lasx_madd_h_b(q4l, q8l);
             p16l = lasx_madd_h(scale_l, p16l);
 
             const __m256i q8h = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
-            __m256i p16h = lasx_maddubs_h(q4h, q8h);
+            __m256i p16h = lasx_madd_h_b(q4h, q8h);
             p16h = lasx_madd_h(scale_h, p16h);
             const __m256i sumj = __lasx_xvadd_w(p16l, p16h);
 
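Note for readers without the LASX/LSX manual at hand. The rewrite replaces the 256-bit widening of the packed scales/mins (lasx_extu8_16) with two 128-bit widening instructions: __lsx_vsllwil_h_b(v, 0) sign-extends the low eight bytes to halfwords (the scales), __lsx_vexth_h_b(v) the high eight (the mins). Signed versus zero extension does not matter here because the unpacked q4_K scales and mins are 6-bit values (0..63). Likewise, the high nibbles now come from a byte-wise shift (xvsrli_b), which cannot leak bits across byte lanes, so the old mask register m4 is no longer needed, and __lasx_xvandi_b takes its 0xF mask as an immediate. The scalar sketch below models what one j iteration of the loop computes; it is an illustration written for this note under the standard q4_K layout (each 32-byte q4 load packs 64 four-bit weights: low nibbles belong to sub-block 2j, high nibbles to sub-block 2j+1), and dot64 is a hypothetical helper name, not code from the repository.

#include <stdint.h>
#include <stdio.h>

/* Scalar model of one j iteration above. q4 holds 32 packed bytes
 * (64 four-bit weights), q8 holds 64 signed 8-bit activations, and
 * scale_l/scale_h are the two 6-bit sub-block scales that scale_l/
 * scale_h broadcast in the vector version. */
static int32_t dot64(const uint8_t q4[32], const int8_t q8[64],
                     int16_t scale_l, int16_t scale_h) {
    int32_t p16l = 0, p16h = 0;
    for (int k = 0; k < 32; ++k) {
        const int8_t lo = (int8_t)(q4[k] & 0xf); /* __lasx_xvandi_b(q4bits, 0xf) */
        const int8_t hi = (int8_t)(q4[k] >> 4);  /* __lasx_xvsrli_b(q4bits, 4)   */
        /* lasx_madd_h_b multiplies signed bytes; maddubs treated q4 as
         * unsigned. Both agree here since the nibbles are only 0..15. */
        p16l += lo * q8[k];
        p16h += hi * q8[k + 32];
    }
    /* lasx_madd_h(scale_*, p16*): scale each half, then sum into sumi. */
    return scale_l * p16l + scale_h * p16h;
}

int main(void) {
    uint8_t q4[32];
    int8_t  q8[64];
    for (int k = 0; k < 32; ++k) q4[k] = (uint8_t)(k * 7);  /* arbitrary test data */
    for (int k = 0; k < 64; ++k) q8[k] = (int8_t)(k - 32);
    printf("dot64 = %d\n", dot64(q4, q8, 13, 42));
    return 0;
}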