Commit a34fc87 (parent 59f6ae1)

Optimize ggml_vec_dot_q4_K_q8_K for LoongArch ASX

1 file changed: +11 −15 lines

ggml/src/ggml-cpu/ggml-cpu-quants.c

Lines changed: 11 additions & 15 deletions
@@ -6569,11 +6569,6 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
     *s = vec_extract(vsumf0, 0);
 
 #elif defined __loongarch_asx
-    GGML_UNUSED(kmask1);
-    GGML_UNUSED(kmask2);
-    GGML_UNUSED(kmask3);
-
-    const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
 
     __m256 acc = (__m256)__lasx_xvldi(0);
     __m128 acc_m = (__m128)__lsx_vldi(0);
@@ -6593,33 +6588,34 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         const uint8_t * restrict q4 = x[i].qs;
         const int8_t * restrict q8 = y[i].qs;
 
-        const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]));
+        const __m128i mins_and_scales128 = lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]);
+        const __m128i mins128 = __lsx_vexth_h_b(mins_and_scales128);
+        const __m128i scales128 = __lsx_vsllwil_h_b(mins_and_scales128, 0);
 
         const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0);
         const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1));
-        const __m128i prod = lsx_madd_h(lasx_extracti128(mins_and_scales, 1), q8s);
+        const __m128i prod = lsx_madd_h(mins128, q8s);
         acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m);
 
-        const __m128i sc128 = lasx_extracti128(mins_and_scales, 0);
-        const __m256i scales = lasx_insertf128(sc128, sc128);
+        const __m256i scales = lasx_insertf128(scales128, scales128);
 
         __m256i sumi = __lasx_xvldi(0);
 
         for (int j = 0; j < QK_K/64; ++j) {
 
-            const __m256i scale_l = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+0));
-            const __m256i scale_h = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+1));
+            const __m256i scale_l = lasx_xvrepl128vei_h(scales, 2 * j + 0);
+            const __m256i scale_h = lasx_xvrepl128vei_h(scales, 2 * j + 1);
 
             const __m256i q4bits = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
-            const __m256i q4l = __lasx_xvand_v(q4bits, m4);
-            const __m256i q4h = __lasx_xvand_v(__lasx_xvsrli_h(q4bits, 4), m4);
+            const __m256i q4l = __lasx_xvandi_b(q4bits, 0xf);
+            const __m256i q4h = __lasx_xvsrli_b(q4bits, 4);
 
             const __m256i q8l = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
-            __m256i p16l = lasx_maddubs_h(q4l, q8l);
+            __m256i p16l = lasx_madd_h_b(q4l, q8l);
             p16l = lasx_madd_h(scale_l, p16l);
 
            const __m256i q8h = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
-            __m256i p16h = lasx_maddubs_h(q4h, q8h);
+            __m256i p16h = lasx_madd_h_b(q4h, q8h);
            p16h = lasx_madd_h(scale_h, p16h);
            const __m256i sumj = __lasx_xvadd_w(p16l, p16h);
 
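The scale/min preparation now stays in 128-bit registers: instead of widening the packed bytes into a 256-bit vector with lasx_extu8_16 and extracting both halves, the new code widens the low eight bytes (the sub-block scales) and the high eight bytes (the sub-block mins) directly with __lsx_vsllwil_h_b(..., 0) and __lsx_vexth_h_b. Below is a scalar sketch of that split, assuming the usual Q4_K decode in which utmp[0..1] carry the scales and utmp[2..3] the mins, and assuming the two intrinsics widen the low and high halves respectively.

#include <stdint.h>

/* Illustrative scalar model only, not the SIMD code: split the 16 packed
 * bytes assembled by lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]) into
 * widened 16-bit scales (low 8 bytes) and mins (high 8 bytes). The packed
 * values are 6-bit, so signed vs. unsigned widening gives the same result. */
static void split_scales_mins(const uint8_t packed[16],
                              int16_t scales[8], int16_t mins[8]) {
    for (int k = 0; k < 8; ++k) {
        scales[k] = (int16_t) packed[k];     /* -> scales128 */
        mins[k]   = (int16_t) packed[8 + k]; /* -> mins128   */
    }
}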

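Similarly, the per-sub-block scale broadcast drops the shuffle table: both 128-bit halves of scales hold the same eight halfwords, and lasx_xvrepl128vei_h(scales, 2*j + k) replicates halfword 2*j + k across each half, so all sixteen 16-bit lanes carry that sub-block's scale. A scalar sketch, assuming replicate-within-each-128-bit-half semantics for xvrepl128vei.h:

#include <stdint.h>

/* Illustrative scalar model only: replicate halfword 'idx' (0..7) within
 * each 128-bit half of a 256-bit vector of int16 lanes. With both halves
 * of 'scales' identical, every lane ends up holding the same scale. */
static void repl128vei_h(int16_t dst[16], const int16_t src[16], int idx) {
    for (int half = 0; half < 2; ++half)
        for (int lane = 0; lane < 8; ++lane)
            dst[half * 8 + lane] = src[half * 8 + idx];
}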
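The inner-loop unpack is where the removed m4 constant came from. The old code shifted 16-bit lanes (__lasx_xvsrli_h), which lets bits from the neighbouring byte leak into the high nibble, so both halves had to be masked against the broadcast 0xF register; the new code shifts and masks byte-wise (__lasx_xvsrli_b, __lasx_xvandi_b with an immediate), so the high nibble comes out already isolated and no mask register is needed. A minimal scalar check of that per-byte equivalence:

#include <assert.h>
#include <stdint.h>

/* Per-byte low/high nibble extraction, mirroring the new
 * __lasx_xvandi_b(q4bits, 0xf) / __lasx_xvsrli_b(q4bits, 4) pair:
 * a byte-wise right shift zero-fills within the byte, so the high
 * nibble needs no mask; only the low nibble is masked. */
int main(void) {
    for (int v = 0; v < 256; ++v) {
        const uint8_t byte = (uint8_t) v;
        const uint8_t lo = byte & 0x0F; /* low  4-bit quant */
        const uint8_t hi = byte >> 4;   /* high 4-bit quant */
        assert(lo < 16 && hi < 16);
        assert((uint8_t)((hi << 4) | lo) == byte);
    }
    return 0;
}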