@@ -13665,6 +13665,81 @@ void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& inf
1366513665 }
1366613666}
1366713667
// Single-RHS-row (GEMV) kernel: Q8_KV left matrix times one Q8_KV vector, NEON path.
// Row memory layout on both sides: [ float scale | float (unused here) | int8_t qs[n] ].
// For each left-hand row the result is scale_x * scale_y * dot(qs_x, qs_y),
// written out through info.store(row, 0, ...).
static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    GGML_ASSERT(n%32 == 0);
    auto yptr = (const float *)info.src1_row(0);
    const float dy = yptr[0];          // RHS scale
    auto y_qs = (const int8_t *)(yptr + 2);
    const int nb64 = n/64;             // number of full 64-byte chunks
    for (int row = 0; row < nrc_x; ++row) {
        auto xptr = (const float *)((const char *)vx + row*bx);
        auto x_qs = (const int8_t *)(xptr + 2);
        int32x4_t sum[4] = {};         // four independent dot-product accumulators
        for (int ib = 0; ib < nb64; ++ib) {
            auto x = vld1q_s8_x4(x_qs + 64*ib);
            for (int k = 0; k < 4; ++k) {
                sum[k] = ggml_vdotq_s32(sum[k], x.val[k], vld1q_s8(y_qs + 64*ib + 16*k));
            }
        }
        // Tail: when n is an odd multiple of 32 there is one 32-byte chunk left.
        if (int ib32 = 2*nb64; ib32 < n/32) {
            auto x = vld1q_s8_x2(x_qs + 32*ib32);
            sum[0] = ggml_vdotq_s32(sum[0], x.val[0], vld1q_s8(y_qs + 32*ib32 +  0));
            sum[1] = ggml_vdotq_s32(sum[1], x.val[1], vld1q_s8(y_qs + 32*ib32 + 16));
        }
        auto total = vaddq_s32(vaddq_s32(sum[0], sum[1]), vaddq_s32(sum[2], sum[3]));
        info.store(row, 0, xptr[0]*dy*vaddvq_s32(total));
    }
}
13696+
// GEMM kernel: Q8_KV left matrix times nrc_y Q8_KV right-hand rows, NEON path.
// Processes left-hand rows four at a time; each iteration transposes four 16-byte
// row fragments into 32-bit-lane columns so that vdotq_laneq_s32 can broadcast one
// RHS lane against four rows at once. Row layout on both sides:
// [ float scale | float (unused here) | int8_t qs[n] ].
// Fix vs. original: the transpose used implicit conversions between distinct NEON
// vector types (int8x16_t -> int32x4_t/int64x2_t), a GCC extension that Clang C++
// rejects without -flax-vector-conversions. Explicit vreinterpretq_* casts are used
// instead, matching the equivalent repack code in HelperQ8KVR8. Behavior unchanged.
template <int nrc_y>
static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    GGML_ASSERT(nrc_x%4 == 0);
    GGML_ASSERT(n%16 == 0);
    int8x16_t qx[4];
    int32x4_t acc[nrc_y] = {};
    float dy[nrc_y];
    const int8_t * q8y[nrc_y];
    for (int iy = 0; iy < nrc_y; ++iy) {
        auto dptr = (const float *)info.src1_row(iy);
        dy[iy] = dptr[0];
        q8y[iy] = (const int8_t *)(dptr + 2);
    }
    const int8_t * q8x[4];
    float dx[4];
    for (int ix = 0; ix < nrc_x; ix += 4) {
        for (int kx = 0; kx < 4; ++kx) {
            auto dptr = (const float *)((const char *)vx + (ix+kx)*bx);
            dx[kx] = dptr[0];
            q8x[kx] = (const int8_t *)(dptr + 2);
        }
        for (int i = 0; i < n/16; ++i) {
            for (int kx = 0; kx < 4; ++kx) qx[kx] = vld1q_s8(q8x[kx] + 16*i);
            // 4x4 transpose of 32-bit lanes: after this, qx[j] holds lane j of all
            // four rows, i.e. the data vdotq_laneq_s32 needs for RHS lane j.
            auto row01 = vtrnq_s32(vreinterpretq_s32_s8(qx[0]), vreinterpretq_s32_s8(qx[1]));
            auto row23 = vtrnq_s32(vreinterpretq_s32_s8(qx[2]), vreinterpretq_s32_s8(qx[3]));
            qx[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
            qx[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
            qx[2] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
            qx[3] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
            for (int iy = 0; iy < nrc_y; ++iy) {
                auto y = vld1q_s8(q8y[iy] + 16*i);
                acc[iy] = vdotq_laneq_s32(acc[iy], qx[0], y, 0);
                acc[iy] = vdotq_laneq_s32(acc[iy], qx[1], y, 1);
                acc[iy] = vdotq_laneq_s32(acc[iy], qx[2], y, 2);
                acc[iy] = vdotq_laneq_s32(acc[iy], qx[3], y, 3);
            }
        }
        // Scale the four integer dot products by scale_x[kx]*scale_y[iy] and store
        // the 4-wide float result for this row group.
        auto scales_x = vld1q_f32(dx);
        for (int iy = 0; iy < nrc_y; ++iy) {
            auto scale = vmulq_f32(scales_x, vdupq_n_f32(dy[iy]));
            info.store(ix, iy, vmulq_f32(scale, vcvtq_f32_s32(acc[iy])));
            acc[iy] = vdupq_n_s32(0);   // reset accumulators for the next row group
        }
    }
}
13742+
1366813743void mul_mat_iq4_nl_r4_q8_0_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
1366913744 GGML_ASSERT(nrc_x%4 == 0);
1367013745 Q8<1, block_q8_0_x4> q8(info);
@@ -14241,6 +14316,12 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
1424114316 SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_k_r8_q8_k);
1424214317 expected_Btype = GGML_TYPE_Q8_KR8;
1424314318 break;
14319+ case GGML_TYPE_Q8_KV:
14320+ SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_KV_q8_KV);
14321+ m.funcs[0] = mul_mat_q8_KV_q8_KV_1;
14322+ m.func16 = mul_mat_q8_KV_q8_KV<16>;
14323+ expected_Btype = GGML_TYPE_Q8_KV;
14324+ break;
1424414325 case GGML_TYPE_IQ2_K_R4:
1424514326 SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq2_k_r4_q8_k);
1424614327 expected_Btype = GGML_TYPE_Q8_K;
@@ -14605,9 +14686,8 @@ struct HelperQ8KV final : public BaseHelper<step> {
1460514686 inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
1460614687 auto q8 = (const block_q8_KV<D> *)Base::lblock(l1);
1460714688#ifdef __aarch64__
14608- auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
14609- int ii = j%QK8_0;
14610- auto qs = vld1_s8_x2(dl->qs + ii);
14689+ auto vd = F16::set1(q8->d);
14690+ auto qs = vld1_s8_x2(q8->qs + 8*i);
1461114691 v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[0])));
1461214692 v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[1])));
1461314693#else
@@ -14859,28 +14939,26 @@ struct HelperQ8KVR8 : public BaseHelper<step> {
1485914939 _mm256_storeu_si256((__m256i *)y[ix].qs + 4*ib+3, m3);
1486014940#elif defined __ARM_NEON
1486114941 // TODO
14862- for (int l = 0; l < 2; ++l) {
14863- m0.val[0] = vld1q_s8(x8[0][ib].qs+16*l); m0.val[1] = vld1q_s8(x8[4][ib].qs+16*l);
14864- m1.val[0] = vld1q_s8(x8[1][ib].qs+16*l); m1.val[1] = vld1q_s8(x8[5][ib].qs+16*l);
14865- m2.val[0] = vld1q_s8(x8[2][ib].qs+16*l); m2.val[1] = vld1q_s8(x8[6][ib].qs+16*l);
14866- m3.val[0] = vld1q_s8(x8[3][ib].qs+16*l); m3.val[1] = vld1q_s8(x8[7][ib].qs+16*l);
14867- auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0]));
14868- auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0]));
14869- m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
14870- m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
14871- m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
14872- m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
14873- row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1]));
14874- row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1]));
14875- m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
14876- m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
14877- m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
14878- m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
14879- vst1q_s8_x2(y[ib].qs + 0 + 128*l, m0);
14880- vst1q_s8_x2(y[ib].qs + 32 + 128*l, m1);
14881- vst1q_s8_x2(y[ib].qs + 64 + 128*l, m2);
14882- vst1q_s8_x2(y[ib].qs + 96 + 128*l, m3);
14883- }
14942+ m0.val[0] = vld1q_s8(x8[0]+16*ib); m0.val[1] = vld1q_s8(x8[4]+16*ib);
14943+ m1.val[0] = vld1q_s8(x8[1]+16*ib); m1.val[1] = vld1q_s8(x8[5]+16*ib);
14944+ m2.val[0] = vld1q_s8(x8[2]+16*ib); m2.val[1] = vld1q_s8(x8[6]+16*ib);
14945+ m3.val[0] = vld1q_s8(x8[3]+16*ib); m3.val[1] = vld1q_s8(x8[7]+16*ib);
14946+ auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0]));
14947+ auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0]));
14948+ m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
14949+ m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
14950+ m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
14951+ m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
14952+ row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1]));
14953+ row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1]));
14954+ m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
14955+ m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
14956+ m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
14957+ m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
14958+ vst1q_s8_x2(y[ix].qs + 0 + 128*ib, m0);
14959+ vst1q_s8_x2(y[ix].qs + 32 + 128*ib, m1);
14960+ vst1q_s8_x2(y[ix].qs + 64 + 128*ib, m2);
14961+ vst1q_s8_x2(y[ix].qs + 96 + 128*ib, m3);
1488414962#else
1488514963 // TODO
1488614964 for (int l = 0; l < 4; ++l) {
0 commit comments