Skip to content

Commit d2ab776

Browse files
committed
Optimize ggml_vec_dot_q2_K_q8_K for LoongArch ASX
1 parent 7cdb39b commit d2ab776

File tree

1 file changed

+17
-23
lines changed

1 file changed

+17
-23
lines changed

ggml/src/ggml-cpu/ggml-cpu-quants.c

Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5000,9 +5000,6 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
50005000

50015001
#elif defined __loongarch_asx
50025002

5003-
const __m256i m3 = __lasx_xvreplgr2vr_b(3);
5004-
const __m128i m4 = __lsx_vreplgr2vr_b(0xF);
5005-
50065003
__m256 acc = (__m256)__lasx_xvldi(0);
50075004

50085005
for (int i = 0; i < nb; ++i) {
@@ -5013,18 +5010,15 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
50135010
const uint8_t * restrict q2 = x[i].qs;
50145011
const int8_t * restrict q8 = y[i].qs;
50155012

5016-
const __m128i mins_and_scales = __lsx_vld((const __m128i*)x[i].scales, 0);
5017-
const __m128i scales8 = __lsx_vand_v(mins_and_scales, m4);
5018-
const __m128i mins8 = __lsx_vand_v(__lsx_vsrli_h(mins_and_scales, 4), m4);
5019-
const __m256i mins = lasx_ext8_16(mins8);
5013+
const __m128i mins_and_scales128 = __lsx_vld((const __m128i*)x[i].scales, 0);
5014+
const __m128i scales128 = __lsx_vandi_b(mins_and_scales128, 0xf);
5015+
const __m256i mins = lasx_ext8_16(__lsx_vsrli_b(mins_and_scales128, 4));
50205016
const __m256i prod = lasx_madd_h(mins, __lasx_xvld((const __m256i*)y[i].bsums, 0));
50215017

50225018
acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(dmin), __lasx_xvffint_s_w(prod), acc);
50235019

5024-
const __m256i all_scales = lasx_ext8_16(scales8);
5025-
const __m128i l_scales = lasx_extracti128(all_scales, 0);
5026-
const __m128i h_scales = lasx_extracti128(all_scales, 1);
5027-
const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)};
5020+
const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
5021+
const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask));
50285022

50295023
__m256i sumi = __lasx_xvldi(0);
50305024

@@ -5037,20 +5031,20 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
50375031
const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
50385032
const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
50395033

5040-
const __m256i q2_0 = __lasx_xvand_v(q2bits, m3);
5041-
const __m256i q2_1 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 2), m3);
5042-
const __m256i q2_2 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 4), m3);
5043-
const __m256i q2_3 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 6), m3);
5034+
const __m256i q2_0 = __lasx_xvandi_b(q2bits, 3);
5035+
const __m256i q2_1 = __lasx_xvandi_b(__lasx_xvsrli_b(q2bits, 2), 3);
5036+
const __m256i q2_2 = __lasx_xvandi_b(__lasx_xvsrli_b(q2bits, 4), 3);
5037+
const __m256i q2_3 = __lasx_xvsrli_b(q2bits, 6);
50445038

5045-
__m256i p0 = lasx_maddubs_h(q2_0, q8_0);
5046-
__m256i p1 = lasx_maddubs_h(q2_1, q8_1);
5047-
__m256i p2 = lasx_maddubs_h(q2_2, q8_2);
5048-
__m256i p3 = lasx_maddubs_h(q2_3, q8_3);
5039+
__m256i p0 = lasx_madd_h_b(q2_0, q8_0);
5040+
__m256i p1 = lasx_madd_h_b(q2_1, q8_1);
5041+
__m256i p2 = lasx_madd_h_b(q2_2, q8_2);
5042+
__m256i p3 = lasx_madd_h_b(q2_3, q8_3);
50495043

5050-
p0 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(0)), p0);
5051-
p1 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(1)), p1);
5052-
p2 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(2)), p2);
5053-
p3 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(3)), p3);
5044+
p0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p0);
5045+
p1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p1);
5046+
p2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p2);
5047+
p3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p3);
50545048

50555049
p0 = __lasx_xvadd_w(p0, p1);
50565050
p2 = __lasx_xvadd_w(p2, p3);

0 commit comments

Comments
 (0)