Commit 1d57041

Author: Iwan Kawrakow (committed)
q8_KV: use it in FA on NEON
1 parent d6ac7a3 commit 1d57041

File tree: 1 file changed (+45, -5 lines)


ggml/src/iqk/iqk_mul_mat.cpp

Lines changed: 45 additions & 5 deletions
@@ -13740,6 +13740,48 @@ static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     }
 }
 
+template <int nrc_y>
+void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    GGML_ASSERT(nrc_x%8 == 0);
+    int32x4_t acc[2*nrc_y] = {};
+    float dy[nrc_y];
+    const int8_t * q8y[nrc_y];
+    for (int iy = 0; iy < nrc_y; ++iy) {
+        auto dptr = (const float *)info.src1_row(iy);
+        dy[iy] = dptr[0];
+        q8y[iy] = (const int8_t *)(dptr + 2);
+    }
+    for (int ix = 0; ix < nrc_x; ix += 8) {
+        const float * dptr = (const float *)((const char *)vx + ix*bx);
+        auto q8x = (const int8_t *)(dptr + 8);
+        for (int ib = 0; ib < n/16; ++ib) {
+            auto q1 = vld1q_s8_x4(q8x + 128*ib + 0);
+            auto q2 = vld1q_s8_x4(q8x + 128*ib + 64);
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto y = vld1q_s8(q8y[iy]+16*ib);
+                acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q1.val[0], y, 0);
+                acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q1.val[1], y, 0);
+                acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q1.val[2], y, 1);
+                acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q1.val[3], y, 1);
+                acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q2.val[0], y, 2);
+                acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q2.val[1], y, 2);
+                acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q2.val[2], y, 3);
+                acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q2.val[3], y, 3);
+            }
+        }
+        auto scale1_x = vld1q_f32(dptr+0);
+        auto scale2_x = vld1q_f32(dptr+4);
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            auto scale_y = vdupq_n_f32(dy[iy]);
+            auto scale1 = vmulq_f32(scale1_x, scale_y);
+            auto scale2 = vmulq_f32(scale2_x, scale_y);
+            info.store(ix+0, iy, vmulq_f32(scale1, vcvtq_f32_s32(acc[2*iy+0])));
+            info.store(ix+4, iy, vmulq_f32(scale2, vcvtq_f32_s32(acc[2*iy+1])));
+            acc[2*iy+0] = acc[2*iy+1] = vdupq_n_s32(0);
+        }
+    }
+}
+
 void mul_mat_iq4_nl_r4_q8_0_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     GGML_ASSERT(nrc_x%4 == 0);
     Q8<1, block_q8_0_x4> q8(info);
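
For reference, here is a minimal scalar sketch of what the new NEON kernel above computes. The packed layout of the 8 interleaved x rows (8 float row scales followed by the int8 quants in 4-row by 4-column tiles) and the y-row header (one float scale, one skipped float, then the int8 quants) are read off the loads and the vdotq_laneq_s32 lane indices, not taken from a documented format; the DataInfo plumbing is replaced by plain pointers, so y_rows, out and out_stride are illustrative names only.

#include <cstddef>
#include <cstdint>

// out[ix+r, iy] = dx[r] * dy[iy] * sum_c x_q[r][c] * y_q[iy][c]
static void q8_KV_r8_times_q8_KV_ref(int n,                        // row length, a multiple of 16
                                     const void * vx, size_t bx,   // packed x rows, 8 per group, group starts at ix*bx
                                     const float * const * y_rows, // per y row: [d, skipped float, int8 quants...]
                                     float * out, size_t out_stride,
                                     int nrc_x, int nrc_y) {
    for (int ix = 0; ix < nrc_x; ix += 8) {
        const float  * dx  = (const float  *)((const char *)vx + ix*bx); // 8 per-row scales
        const int8_t * q8x = (const int8_t *)(dx + 8);                   // interleaved int8 quants
        for (int iy = 0; iy < nrc_y; ++iy) {
            const float    dy  = y_rows[iy][0];
            const int8_t * q8y = (const int8_t *)(y_rows[iy] + 2);
            for (int r = 0; r < 8; ++r) {
                int32_t sum = 0;
                for (int c = 0; c < n; ++c) {
                    // 128 bytes per block of 16 columns; inside a block the quants sit in
                    // 4x4 tiles: rows 0-3 then rows 4-7, for columns 0-3, 4-7, 8-11, 12-15.
                    int ib = c/16, cb = c%16;
                    int idx = 128*ib + 32*(cb/4) + 16*(r/4) + 4*(r%4) + cb%4;
                    sum += q8x[idx] * q8y[c];
                }
                out[(ix + r) + iy*out_stride] = dx[r] * dy * (float)sum;
            }
        }
    }
}

In the NEON code, one iteration of the ib loop covers exactly this inner sum for one 16-column block: q1.val[0..3] and q2.val[0..3] hold the eight 4x4 tiles of the block, lane i of y supplies columns 4*i..4*i+3, and acc[2*iy+0]/acc[2*iy+1] keep the running sums for x rows 0-3 and 4-7.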
@@ -15827,7 +15869,9 @@ struct FlashQKfp32 {
         }
         else if constexpr (std::is_same_v<KHelper, HelperQ8KV<D, k_step>>) {
 #ifdef __aarch64__
-            MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ80>, nq);
+            if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16);
+            if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1, 1);
+            MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq);
 #else
 #ifdef HAVE_FANCY_SIMD
             if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16);
@@ -15844,14 +15888,10 @@ struct FlashQKfp32 {
 #endif
         }
         else if constexpr (std::is_same_v<KHelper, HelperQ8KVR8<D, k_step>>) {
-#ifdef __aarch64__
-            MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ80>, nq);
-#else
 #ifdef HAVE_FANCY_SIMD
             if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_r8_q8_KV<16>, 16);
 #endif
             MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_r8_q8_KV, nq);
-#endif
         }
         else if constexpr (std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
 #ifdef __aarch64__
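
The two hunks above switch the kernel selection in FlashQKfp32 on __aarch64__: a q8_KV K cache (HelperQ8KV) and its row-interleaved variant (HelperQ8KVR8) now dispatch to the q8_KV kernels instead of the generic mul_mat_qX_0_q8_0 path, and the R8 branch no longer needs its own #ifdef because the same code now serves NEON and x86. Below is a hedged sketch of the selection pattern; it assumes the kernel declarations and DataInfo from this file are in scope and that MAKE_FUNCS_ONLY_NRC expands to a plain switch over nq (the real macro is defined elsewhere in iqk_mul_mat.cpp, and the exact set of nq values is a guess).

#include <cstddef>
#include <utility>

// Kernel signature shared by the mul_mat_* functions in this file.
using mul_mat_t = void (*)(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x);

// Returns (kernel, number of q rows handled per call) for the HelperQ8KV
// branch on NEON, mirroring the code added in the diff above.
static std::pair<mul_mat_t, int> select_q8_KV_mul_mat(int nq) {
    if (nq%16 == 0) return { mul_mat_q8_KV_q8_KV<16>, 16 };   // large batches: 16 q rows at a time
    if (nq == 1)    return { mul_mat_q8_KV_q8_KV_1, 1 };      // single q row (token generation)
    switch (nq) {                                              // assumed shape of MAKE_FUNCS_ONLY_NRC
        case 2:  return { mul_mat_q8_KV_q8_KV<2>, 2 };
        case 3:  return { mul_mat_q8_KV_q8_KV<3>, 3 };
        case 4:  return { mul_mat_q8_KV_q8_KV<4>, 4 };
        case 5:  return { mul_mat_q8_KV_q8_KV<5>, 5 };
        case 6:  return { mul_mat_q8_KV_q8_KV<6>, 6 };
        case 7:  return { mul_mat_q8_KV_q8_KV<7>, 7 };
        default: return { mul_mat_q8_KV_q8_KV<8>, 8 };        // otherwise process 8 q rows per call
    }
}

The returned row count presumably tells the caller how many q rows each kernel invocation consumes, which is why the multiple-of-16 and nq == 1 cases are peeled off before the generic switch.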
