Skip to content

Commit 66ed5e3

Browse files
committed
Optimize ggml_vec_dot_iq4_xs_q8_K for LoongArch ASX
1 parent 1c27b04 commit 66ed5e3

File tree

1 file changed

+11
-51
lines changed

1 file changed

+11
-51
lines changed

ggml/src/ggml-cpu/ggml-cpu-quants.c

Lines changed: 11 additions & 51 deletions
Original file line number | Diff line number | Diff line change
@@ -10383,13 +10383,9 @@ static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
1038310383
}
1038410384
#elif defined(__loongarch_asx)
1038510385
static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
10386-
const __m256i ax = __lasx_xvsigncov_b(x, x);
10387-
const __m256i sy = __lasx_xvsigncov_b(x, y);
10388-
__m256i tmp1, tmp2, tmp3;
10389-
tmp1 = __lasx_xvmulwev_h_bu_b(ax, sy);
10390-
tmp2 = __lasx_xvmulwod_h_bu_b(ax, sy);
10391-
tmp3 = __lasx_xvadd_h(tmp1, tmp2);
10392-
return __lasx_xvsat_h(tmp3, 15);
10386+
const __m256i a = __lasx_xvmulwev_h_b(x, y);
10387+
const __m256i b = __lasx_xvmulwod_h_b(x, y);
10388+
return __lasx_xvadd_h(a, b);
1039310389
}
1039410390
#endif
1039510391

@@ -11439,67 +11435,31 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void *
1143911435
#elif defined(__loongarch_asx)
1144011436

1144111437
const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0);
11442-
const __m128i m4b = __lsx_vreplgr2vr_b(0x0f);
1144311438

1144411439
__m256 accum = (__m256)__lasx_xvldi(0);
11445-
__m256i tmp1;
11446-
__m128i tmp0, tmp2, tmp3, tmp4, mask_8f, mask;
1144711440

11448-
mask_8f = __lsx_vreplgr2vr_b(0x8f);
1144911441
for (int ibl = 0; ibl < nb; ++ibl) {
1145011442
const uint8_t * qs = x[ibl].qs;
1145111443
const int8_t * q8 = y[ibl].qs;
1145211444
uint16_t sh = x[ibl].scales_h;
1145311445
__m256i sumi1 = __lasx_xvldi(0);
1145411446
__m256i sumi2 = __lasx_xvldi(0);
11455-
__m128i zero = __lsx_vldi(0);
1145611447
for (int ib = 0; ib < QK_K/32; ib += 2) {
11457-
const __m128i q4bits_1 = __lsx_vld((const __m128i*)qs, 0); qs += 16;
11458-
const __m128i q4bits_2 = __lsx_vld((const __m128i*)qs, 0); qs += 16;
11448+
const __m128i q4bits_1 = __lsx_vld((const __m128i*)qs, 0); qs += 16;
11449+
const __m128i q4bits_2 = __lsx_vld((const __m128i*)qs, 0); qs += 16;
1145911450
const __m256i q8b_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
1146011451
const __m256i q8b_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
11461-
tmp2 = __lsx_vand_v(__lsx_vand_v(__lsx_vsrli_h(q4bits_1, 4), m4b), mask_8f);
11462-
tmp0 = __lsx_vori_b(tmp2, 0x10);
11463-
mask = __lsx_vsle_b(zero, tmp2);
11464-
tmp3 = __lsx_vand_v(tmp0, mask);
11465-
tmp3 = __lsx_vshuf_b(values128, zero, tmp3);
11466-
11467-
tmp2 = __lsx_vand_v(__lsx_vand_v(q4bits_1, m4b), mask_8f);
11468-
tmp0 = __lsx_vori_b(tmp2, 0x10);
11469-
mask = __lsx_vsle_b(zero, tmp2);
11470-
tmp4 = __lsx_vand_v(tmp0, mask);
11471-
tmp4 = __lsx_vshuf_b(values128, zero, tmp4);
11472-
11473-
const __m256i q4b_1 = lasx_insertf128(tmp3, tmp4);
11474-
11475-
tmp2 = __lsx_vand_v(__lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b), mask_8f);
11476-
tmp0 = __lsx_vori_b(tmp2, 0x10);
11477-
mask = __lsx_vsle_b(zero, tmp2);
11478-
tmp3 = __lsx_vand_v(tmp0, mask);
11479-
tmp3 = __lsx_vshuf_b(values128, zero, tmp3);
11480-
11481-
tmp2 = __lsx_vand_v(__lsx_vand_v(q4bits_2, m4b), mask_8f);
11482-
tmp0 = __lsx_vori_b(tmp2, 0x10);
11483-
mask = __lsx_vsle_b(zero, tmp2);
11484-
tmp4 = __lsx_vand_v(tmp0, mask);
11485-
tmp4 = __lsx_vshuf_b(values128, zero, tmp4);
11486-
11487-
const __m256i q4b_2 = lasx_insertf128(tmp3, tmp4);
11488-
11452+
const __m256i q4b_1 = lasx_insertf128(__lsx_vshuf_b(values128, values128, __lsx_vsrli_b(q4bits_1, 4)),
11453+
__lsx_vshuf_b(values128, values128, __lsx_vandi_b(q4bits_1, 0xf)));
11454+
const __m256i q4b_2 = lasx_insertf128(__lsx_vshuf_b(values128, values128, __lsx_vsrli_b(q4bits_2, 4)),
11455+
__lsx_vshuf_b(values128, values128, __lsx_vandi_b(q4bits_2, 0xf)));
1148911456
const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
1149011457
const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
1149111458
const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32;
1149211459
const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32;
1149311460
sh >>= 4;
11494-
__m256i tmp5, tmp6;
11495-
tmp1 = __lasx_xvreplgr2vr_h(ls1);
11496-
tmp5 = __lasx_xvmulwev_w_h(p16_1, tmp1);
11497-
tmp6 = __lasx_xvmulwod_w_h(p16_1, tmp1);
11498-
const __m256i p_1 = __lasx_xvadd_w(tmp5, tmp6);
11499-
tmp1 = __lasx_xvreplgr2vr_h(ls2);
11500-
tmp5 = __lasx_xvmulwev_w_h(p16_2, tmp1);
11501-
tmp6 = __lasx_xvmulwod_w_h(p16_2, tmp1);
11502-
const __m256i p_2 = __lasx_xvadd_w(tmp5, tmp6);
11461+
const __m256i p_1 = lasx_madd_h(p16_1, __lasx_xvreplgr2vr_h(ls1));
11462+
const __m256i p_2 = lasx_madd_h(p16_2, __lasx_xvreplgr2vr_h(ls2));
1150311463
sumi1 = __lasx_xvadd_w(p_1, sumi1);
1150411464
sumi2 = __lasx_xvadd_w(p_2, sumi2);
1150511465
}

0 commit comments

Comments
 (0)