Commit 7cdb39b

Optimize ggml_vec_dot_q5_K_q8_K for LoongArch ASX
1 parent f832891 commit 7cdb39b

File tree

1 file changed: +22 -36 lines

ggml/src/ggml-cpu/ggml-cpu-quants.c (22 additions & 36 deletions)
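The routine touched below rebuilds each 5-bit Q5_K weight from a low nibble stored in qs plus a single high bit stored in qh, then accumulates the dot product against the Q8_K quants. A minimal scalar sketch of that reconstruction, assuming the usual Q5_K packing; the helper name and parameters are illustrative, not code from this repository:

    #include <stdint.h>

    /* Rebuild one 5-bit quant: low 4 bits from qs, bit `bit` of qh as bit 4.
     * Illustrative helper, not the repository's reference implementation. */
    static inline int q5_reconstruct(uint8_t ql, uint8_t qh_byte, int bit) {
        int high = (qh_byte >> bit) & 1;   /* the extra high bit kept in qh      */
        return (ql & 0x0F) | (high << 4);  /* 5-bit value in the range 0..31     */
    }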
@@ -7292,19 +7292,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
     *s = vec_extract(vsumf0, 0);
 
 #elif defined __loongarch_asx
-    GGML_UNUSED(kmask1);
-    GGML_UNUSED(kmask2);
-    GGML_UNUSED(kmask3);
-
-    const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
-    const __m128i mzero = __lsx_vldi(0);
-    const __m256i mone = __lasx_xvreplgr2vr_b(1);
 
     __m256 acc = (__m256)__lasx_xvldi(0);
+    __m128 acc_m = (__m128)__lsx_vldi(0);
 
-    float summs = 0.f;
-
-    for (int i = 0; i < nb; ++i) {
+    for (int i = 0; i < nb; ++i) {
 
         const uint8_t * restrict q5 = x[i].qs;
         const int8_t  * restrict q8 = y[i].qs;
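The first hunk drops the unused kmask markers and constants and replaces the scalar summs accumulator, which required a horizontal sum inside every block iteration, with a 4-lane float accumulator acc_m that is folded once after the loop. A scalar sketch of the equivalence, with illustrative names and array shapes that are not taken from the repository:

    #include <stdint.h>

    /* Illustrative only: the same mins-times-bsums reduction done two ways.
     * The old code folded `prod` to one float inside every block iteration;
     * the new code keeps a 4-lane accumulator and folds it once at the end. */
    static float mins_dot_old(const int32_t (*prod)[4], const float *dmin, int nb) {
        float summs = 0.0f;
        for (int i = 0; i < nb; ++i) {
            summs += dmin[i] * (float)(prod[i][0] + prod[i][1] + prod[i][2] + prod[i][3]);
        }
        return summs;
    }

    static float mins_dot_new(const int32_t (*prod)[4], const float *dmin, int nb) {
        float acc_m[4] = {0.0f, 0.0f, 0.0f, 0.0f};
        for (int i = 0; i < nb; ++i) {
            for (int k = 0; k < 4; ++k) {
                acc_m[k] += dmin[i] * (float)prod[i][k];  /* one vfmadd_s per block in the LASX code */
            }
        }
        return acc_m[0] + acc_m[1] + acc_m[2] + acc_m[3];  /* single horizontal fold after the loop */
    }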
@@ -7319,49 +7311,40 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         utmp[2] = uaux;
         utmp[0] &= kmask1;
 
-        const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]));
+        const __m128i mins_and_scales128 = lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]);
+        const __m128i mins128 = __lsx_vexth_h_b(mins_and_scales128);
+        const __m128i scales128 = __lsx_vsllwil_h_b(mins_and_scales128, 0);
 
         const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0);
         const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1));
-        const __m128i prod = lsx_madd_h(lasx_extracti128(mins_and_scales, 1), q8s);
-        const __m128i hsum = lsx_hadd_w(lsx_hadd_w(prod, mzero), mzero);
-        summs += dmin * __lsx_vpickve2gr_w(hsum, 0);    //TODO check
+        const __m128i prod = lsx_madd_h(mins128, q8s);
+        acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m);
 
-        const __m128i sc128 = lasx_extracti128(mins_and_scales, 0);
-        const __m256i scales = lasx_insertf128(sc128, sc128);
+        const __m256i scales = lasx_insertf128(scales128, scales128);
 
         const __m256i hbits = __lasx_xvld((const __m256i*)x[i].qh, 0);
-        __m256i hmask = mone;
 
         __m256i sumi = __lasx_xvldi(0);
 
-        int bit = 0;
-        __m256i xvbit;
-
         for (int j = 0; j < QK_K/64; ++j) {
 
-            const __m256i scale_0 = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+0));
-            const __m256i scale_1 = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+1));
+            const __m256i scale_0 = lasx_xvrepl128vei_h(scales, 2 * j + 0);
+            const __m256i scale_1 = lasx_xvrepl128vei_h(scales, 2 * j + 1);
 
             const __m256i q5bits = __lasx_xvld((const __m256i*)q5, 0); q5 += 32;
 
-            xvbit = __lasx_xvreplgr2vr_h(bit++);
-            const __m256i q5l_0 = __lasx_xvand_v(q5bits, m4);
-            const __m256i q5h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4);
-            const __m256i q5_0 = __lasx_xvadd_b(q5l_0, q5h_0);
-            hmask = __lasx_xvslli_h(hmask, 1);
-
-            xvbit = __lasx_xvreplgr2vr_h(bit++);
-            const __m256i q5l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q5bits, 4), m4);
-            const __m256i q5h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4);
-            const __m256i q5_1 = __lasx_xvadd_b(q5l_1, q5h_1);
-            hmask = __lasx_xvslli_h(hmask, 1);
+            const __m256i q5l_0 = __lasx_xvandi_b(q5bits, 0xf);
+            const __m256i q5l_1 = __lasx_xvsrli_b(q5bits, 4);
+            const __m256i q5h_0 = __lasx_xvnori_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 2 * j + 0), 0), 0xef);
+            const __m256i q5h_1 = __lasx_xvnori_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 2 * j + 1), 0), 0xef);
+            const __m256i q5_0 = __lasx_xvor_v(q5l_0, q5h_0);
+            const __m256i q5_1 = __lasx_xvor_v(q5l_1, q5h_1);
 
             const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
             const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
 
-            __m256i p16_0 = lasx_maddubs_h(q5_0, q8_0);
-            __m256i p16_1 = lasx_maddubs_h(q5_1, q8_1);
+            __m256i p16_0 = lasx_madd_h_b(q5_0, q8_0);
+            __m256i p16_1 = lasx_madd_h_b(q5_1, q8_1);
 
             p16_0 = lasx_madd_h(scale_0, p16_0);
             p16_1 = lasx_madd_h(scale_1, p16_1);
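The inner loop no longer carries a running hmask/bit pair; each iteration tests bit 2*j+k of hbits directly. Assuming lasx_xvandi_b_bit masks a single bit as its name suggests, __lasx_xvseqi_b(x, 0) yields 0xFF in the bytes where that bit is clear, and __lasx_xvnori_b(x, 0xef) turns that into 0x10 exactly where the bit is set, so OR-ing it onto the low nibble produces the 5-bit quant. A byte-wise sketch of the trick, using a hypothetical helper rather than repository code:

    #include <assert.h>
    #include <stdint.h>

    /* Byte-wise view of the high-bit trick (illustrative, not the LASX code):
     * select == 0xFF where the tested bit is clear (xvseqi_b ... 0),
     * ~(select | 0xEF) == 0x10 exactly where the bit is set (xvnori_b ... 0xef). */
    static uint8_t q5h_byte(uint8_t hbits, int bit) {
        uint8_t masked = hbits & (uint8_t)(1u << bit);          /* xvandi_b with a single-bit mask */
        uint8_t select = (uint8_t)(masked == 0 ? 0xFF : 0x00);  /* xvseqi_b(x, 0)                  */
        return (uint8_t)~(select | 0xEF);                       /* xvnori_b(x, 0xef): 0x10 or 0x00 */
    }

    int main(void) {
        for (int bit = 0; bit < 8; ++bit) {
            assert(q5h_byte((uint8_t)(1u << bit), bit) == 0x10); /* bit set   -> add 16 */
            assert(q5h_byte(0x00, bit) == 0x00);                 /* bit clear -> add 0  */
        }
        return 0;
    }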
@@ -7375,7 +7358,10 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
     }
 
-    *s = hsum_float_8(acc) + summs;
+    acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vbsrl_v(acc_m, 8));
+    acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vbsrl_v(acc_m, 4));
+
+    *s = hsum_float_8(acc) + ((v4f32)acc_m)[0];
 
 #else
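The epilogue folds acc_m by adding byte-shifted copies of itself: the 8-byte shift adds lane pairs {0,2} and {1,3}, and the 4-byte shift then combines the two partial sums, leaving the total in lane 0, which is added to hsum_float_8(acc). The same reduction on plain floats, as an illustrative sketch:

    /* Scalar view of the acc_m epilogue: two shift-and-add steps reduce
     * four float lanes to a single sum in lane 0 (illustrative only). */
    static float hsum4(const float v[4]) {
        float lo = v[0] + v[2];  /* vbsrl_v by 8 bytes, then vfadd_s */
        float hi = v[1] + v[3];
        return lo + hi;          /* vbsrl_v by 4 bytes, then vfadd_s */
    }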
