@@ -13740,6 +13740,48 @@ static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInf
1374013740 }
1374113741}
1374213742
13743+ template <int nrc_y>
13744+ void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
13745+ GGML_ASSERT(nrc_x%8 == 0);
13746+ int32x4_t acc[2*nrc_y] = {};
13747+ float dy[nrc_y];
13748+ const int8_t * q8y[nrc_y];
13749+ for (int iy = 0; iy < nrc_y; ++iy) {
13750+ auto dptr = (const float *)info.src1_row(iy);
13751+ dy[iy] = dptr[0];
13752+ q8y[iy] = (const int8_t *)(dptr + 2);
13753+ }
13754+ for (int ix = 0; ix < nrc_x; ix += 8) {
13755+ const float * dptr = (const float *)((const char *)vx + ix*bx);
13756+ auto q8x = (const int8_t *)(dptr + 8);
13757+ for (int ib = 0; ib < n/16; ++ib) {
13758+ auto q1 = vld1q_s8_x4(q8x + 128*ib + 0);
13759+ auto q2 = vld1q_s8_x4(q8x + 128*ib + 64);
13760+ for (int iy = 0; iy < nrc_y; ++iy) {
13761+ auto y = vld1q_s8(q8y[iy]+16*ib);
13762+ acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q1.val[0], y, 0);
13763+ acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q1.val[1], y, 0);
13764+ acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q1.val[2], y, 1);
13765+ acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q1.val[3], y, 1);
13766+ acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q2.val[0], y, 2);
13767+ acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q2.val[1], y, 2);
13768+ acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q2.val[2], y, 3);
13769+ acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q2.val[3], y, 3);
13770+ }
13771+ }
13772+ auto scale1_x = vld1q_f32(dptr+0);
13773+ auto scale2_x = vld1q_f32(dptr+4);
13774+ for (int iy = 0; iy < nrc_y; ++iy) {
13775+ auto scale_y = vdupq_n_f32(dy[iy]);
13776+ auto scale1 = vmulq_f32(scale1_x, scale_y);
13777+ auto scale2 = vmulq_f32(scale2_x, scale_y);
13778+ info.store(ix+0, iy, vmulq_f32(scale1, vcvtq_f32_s32(acc[2*iy+0])));
13779+ info.store(ix+4, iy, vmulq_f32(scale2, vcvtq_f32_s32(acc[2*iy+1])));
13780+ acc[2*iy+0] = acc[2*iy+1] = vdupq_n_f32(0.f);
13781+ }
13782+ }
13783+ }
13784+
1374313785void mul_mat_iq4_nl_r4_q8_0_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
1374413786 GGML_ASSERT(nrc_x%4 == 0);
1374513787 Q8<1, block_q8_0_x4> q8(info);
@@ -15827,7 +15869,9 @@ struct FlashQKfp32 {
1582715869 }
1582815870 else if constexpr (std::is_same_v<KHelper, HelperQ8KV<D, k_step>>) {
1582915871#ifdef __aarch64__
15830- MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ80, nq);
15872+ if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16);
15873+ if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1, 1);
15874+ MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq);
1583115875#else
1583215876#ifdef HAVE_FANCY_SIMD
1583315877 if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16);
@@ -15844,14 +15888,10 @@ struct FlashQKfp32 {
1584415888#endif
1584515889 }
1584615890 else if constexpr (std::is_same_v<KHelper, HelperQ8KVR8<D, k_step>>) {
15847- #ifdef __aarch64__
15848- MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ80, nq);
15849- #else
1585015891#ifdef HAVE_FANCY_SIMD
1585115892 if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_r8_q8_KV<16>, 16);
1585215893#endif
1585315894 MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_r8_q8_KV, nq);
15854- #endif
1585515895 }
1585615896 else if constexpr (std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
1585715897#ifdef __aarch64__
0 commit comments