@@ -6173,7 +6173,7 @@ static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataIn
 
 // HAVE_FANCY_SIMD should only be defined under #if defined(__AVX512VNNI__) && defined(__AVX512VL__)
 template <int nrc_y>
-static void mul_mat_q8_KV_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+static void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     GGML_ASSERT(nrc_x%8 == 0);
     GGML_ASSERT(n%32 == 0);
 #ifndef HAVE_FANCY_SIMD
@@ -6192,15 +6192,15 @@ static void mul_mat_q8_KV_r8_q8_k(int n, const void * vx, size_t bx, const DataI
         dy[iy] = dptr[0];
 #ifdef HAVE_FANCY_SIMD
         auto iptr = (const int32_t *)(dptr + 1);
-        sy[iy] = -128*iptr[0];
+        sy[iy] = -127*iptr[0];
 #endif
         q8y[iy] = (const int8_t *)(dptr + 2);
     }
     for (int ix = 0; ix < nrc_x; ix += 8) {
         auto dptr = (const float *)((const char *)vx + ix*bx);
         auto dx = _mm256_loadu_ps(dptr);
         auto q8x = (const int8_t *)(dptr + 8);
-        for (int ib = 0; ib < nb; ++ib) { // Blocks of 32
+        for (int ib = 0; ib < nb; ++ib) { // Blocks of 16 quants from each of the 8 interleaved rows
             qx[0] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+0);
             qx[1] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+1);
             qx[2] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+2);
@@ -6229,18 +6229,14 @@ static void mul_mat_q8_KV_r8_q8_k(int n, const void * vx, size_t bx, const DataI
                 acc[iy] = _mm256_add_epi32(acc[iy], _mm256_add_epi32(sumi12, sumi34));
 #endif
             }
-            for (int iy = 0; iy < nrc_y; ++iy) {
-                auto scale = _mm256_mul_ps(dx, _mm256_set1_ps(dy[iy]));
-#ifdef HAVE_FANCY_SIMD
-                acc[iy] = _mm256_add_epi32(acc[iy], _mm256_set1_epi32(sy[iy]));
-#endif
-                info.store(ix, iy, _mm256_mul_ps(scale, _mm256_cvtepi32_ps(acc[iy])));
-                acc[iy] = _mm256_setzero_si256();
-            }
         }
         for (int iy = 0; iy < nrc_y; ++iy) {
-            info.store(ix, iy, acc[iy]);
-            acc[iy] = _mm256_setzero_ps();
+            auto scale = _mm256_mul_ps(dx, _mm256_set1_ps(dy[iy]));
+#ifdef HAVE_FANCY_SIMD
+            acc[iy] = _mm256_add_epi32(acc[iy], _mm256_set1_epi32(sy[iy]));
+#endif
+            info.store(ix, iy, _mm256_mul_ps(scale, _mm256_cvtepi32_ps(acc[iy])));
+            acc[iy] = _mm256_setzero_si256();
         }
     }
 }
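
For orientation, the following is a scalar sketch of what the corrected loop above computes for one group of 8 interleaved K rows and one query row: the integer dot products are accumulated over all 16-quant blocks first, and only then scaled once by the two float scales. The q8_KV row layout (float scale, int32 sum of the quants, then the quants) and the +127 bias applied during repacking on the AVX512-VNNI path are taken from the surrounding code; the function and parameter names here are illustrative only.

```cpp
#include <cstdint>

// Scalar reference: one 8-row interleaved K group times one q8_KV query row.
// xd[8]        : per-row float scales of the 8 packed K rows
// xq           : 8*n int8 quants, interleaved in 4-byte groups of 16 quants per row
// yd, ysum, yq : scale, quant sum, and quants of the query row
// shifted      : true if the packed quants carry the +127 bias (AVX512-VNNI path)
static void q8_KV_r8_dot_ref(int n, const float *xd, const int8_t *xq,
                             float yd, int32_t ysum, const int8_t *yq,
                             bool shifted, float *out) {
    for (int row = 0; row < 8; ++row) {
        int32_t acc = 0;
        for (int ib = 0; ib < n/16; ++ib) {                  // blocks of 16 quants per row
            for (int j = 0; j < 16; ++j) {
                // quant (row, 16*ib + j) is stored at 128*ib + 32*(j/4) + 4*row + (j%4)
                int idx = 128*ib + 32*(j/4) + 4*row + (j%4);
                acc += int32_t(xq[idx]) * int32_t(yq[16*ib + j]);
            }
        }
        if (shifted) acc -= 127*ysum;  // (q + 127)·y = q·y + 127*sum(y), so subtract 127*sum(y)
        out[row] = xd[row] * yd * float(acc);                // scale once, after all blocks
    }
}
```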
@@ -14802,6 +14798,104 @@ struct HelperQ80R8 : public BaseHelper<step> {
     std::vector<block_q8_0_r8> r4;
 };
 
+// TODO: unite this with the above
+template <int D, int step>
+struct HelperQ8KVR8 : public BaseHelper<step> {
+    using Base = BaseHelper<step>;
+    constexpr static int block_size_q = D;
+    using block_q8 = block_q8_KV<D>;
+
+    struct block_q8_KV_r8 {
+        float  d[8];
+        int8_t qs[8*D];
+    };
+
+    HelperQ8KVR8(int nk, const HelperQ8KV<D, step>& q8) : Base(q8.data, q8.stride) {
+        r4 = repack(nk, q8);
+        Base::data = (const char *)r4.data();
+        Base::stride = sizeof(block_q8_KV_r8)/8;
+    }
+
+    static std::vector<block_q8_KV_r8> repack(int nk, const HelperQ8KV<D, step>& q8) {
+        static_assert(D%32 == 0);
+        GGML_ASSERT(nk%8 == 0);
+        std::vector<block_q8_KV_r8> result(nk/8);
+        auto y = result.data();
+#ifdef __ARM_NEON
+        int8x16x2_t m0, m1, m2, m3;
+#endif
+        const int8_t * x8[8];
+        for (int ix = 0; ix < nk/8; ++ix) {
+            for (int k = 0; k < 8; ++k) {
+                auto dptr = (const float *)(q8.data + (8*ix + k)*q8.stride);
+                y[ix].d[k] = dptr[0];
+                x8[k] = (const int8_t *)(dptr + 2);
+            }
+            for (int ib = 0; ib < D/16; ++ib) {
+#ifdef __AVX2__
+                auto m0 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[4]+ib), _mm_loadu_si128((const __m128i *)x8[0]+ib));
+                auto m1 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[5]+ib), _mm_loadu_si128((const __m128i *)x8[1]+ib));
+                auto m2 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[6]+ib), _mm_loadu_si128((const __m128i *)x8[2]+ib));
+                auto m3 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[7]+ib), _mm_loadu_si128((const __m128i *)x8[3]+ib));
+                auto t0 = _mm256_unpacklo_epi32(m0, m1);
+                auto t1 = _mm256_unpacklo_epi32(m2, m3);
+                auto t2 = _mm256_unpackhi_epi32(m0, m1);
+                auto t3 = _mm256_unpackhi_epi32(m2, m3);
+                m0 = _mm256_unpacklo_epi64(t0, t1);
+                m1 = _mm256_unpackhi_epi64(t0, t1);
+                m2 = _mm256_unpacklo_epi64(t2, t3);
+                m3 = _mm256_unpackhi_epi64(t2, t3);
+#ifdef HAVE_FANCY_SIMD
+                m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127));
+                m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127));
+                m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127));
+                m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127));
+#endif
+                _mm256_storeu_si256((__m256i *)y[ix].qs + 4*ib+0, m0);
+                _mm256_storeu_si256((__m256i *)y[ix].qs + 4*ib+1, m1);
+                _mm256_storeu_si256((__m256i *)y[ix].qs + 4*ib+2, m2);
+                _mm256_storeu_si256((__m256i *)y[ix].qs + 4*ib+3, m3);
+#elif defined __ARM_NEON
+                // NEON version of the same 8x16 transpose: here x8[k] are raw int8_t rows
+                // (not block_q8_0 arrays as in HelperQ80R8), and each ib covers 16 quants
+                // per row, so one tile is handled per iteration and stored at 128*ib.
+                m0.val[0] = vld1q_s8(x8[0] + 16*ib); m0.val[1] = vld1q_s8(x8[4] + 16*ib);
+                m1.val[0] = vld1q_s8(x8[1] + 16*ib); m1.val[1] = vld1q_s8(x8[5] + 16*ib);
+                m2.val[0] = vld1q_s8(x8[2] + 16*ib); m2.val[1] = vld1q_s8(x8[6] + 16*ib);
+                m3.val[0] = vld1q_s8(x8[3] + 16*ib); m3.val[1] = vld1q_s8(x8[7] + 16*ib);
+                auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0]));
+                auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0]));
+                m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+                m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+                m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+                m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+                row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1]));
+                row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1]));
+                m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+                m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+                m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+                m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+                vst1q_s8_x2(y[ix].qs + 128*ib +  0, m0);
+                vst1q_s8_x2(y[ix].qs + 128*ib + 32, m1);
+                vst1q_s8_x2(y[ix].qs + 128*ib + 64, m2);
+                vst1q_s8_x2(y[ix].qs + 128*ib + 96, m3);
+#else
+                // Generic fallback: interleave the 8 rows in 4-byte groups, 16 quants per
+                // row per iteration, matching the layout produced by the SIMD paths above.
+                for (int l = 0; l < 4; ++l) {
+                    for (int k = 0; k < 8; ++k) for (int i = 0; i < 4; ++i) {
+                        y[ix].qs[128*ib + 32*l + 4*k + i] = x8[k][16*ib + 4*l + i];
+                    }
+                }
+#endif
+            }
+        }
+        return result;
+    }
+
+    std::vector<block_q8_KV_r8> r4;
+};
+
 template <int D, int step>
 struct HelperQ40 final : public BaseHelper<step> {
     using Base = BaseHelper<step>;
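
The layout produced by repack() can be summarized in a few scalar lines: eight consecutive K rows are fused into one block_q8_KV_r8, with the quants regrouped in 4-byte units so that each 32-byte load picks up the same four quant positions from all eight rows. A minimal sketch under the same assumptions as the code above; the helper name is illustrative.

```cpp
#include <cstdint>

// Scalar picture of the 8-row interleave done by HelperQ8KVR8::repack.
// src[k] points at the D int8 quants of row k; dst is the 8*D byte qs[] of the packed block.
static void repack8_ref(int D, const int8_t *src[8], int8_t *dst) {
    for (int ib = 0; ib < D/16; ++ib)        // 128-byte chunk: 16 quants from each row
        for (int l = 0; l < 4; ++l)          // 4-byte group inside those 16 quants
            for (int k = 0; k < 8; ++k)      // row
                for (int i = 0; i < 4; ++i)  // byte inside the group
                    dst[128*ib + 32*l + 4*k + i] = src[k][16*ib + 4*l + i];
}
```

A 32-byte load at offset 128*ib + 32*l then holds quants 16*ib+4*l .. 16*ib+4*l+3 of all eight rows, which is what the qx[0..3] loads in mul_mat_q8_KV_r8_q8_KV consume.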
@@ -15365,9 +15459,9 @@ struct FlashQKV {
     }
 
     inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int j, const qkv_cache_t * R, float * qkv) const {
-        GGML_ASSERT(fms.S[j] > 0);
-        auto norm = F16::set1(1/fms.S[j]);
-        //auto norm = F16::set1(fms.S[j] > 0 ? 1/fms.S[j] : 0.f);
+        //GGML_ASSERT(fms.S[j] > 0);
+        //auto norm = F16::set1(1/fms.S[j]);
+        auto norm = F16::set1(fms.S[j] > 0 ? 1/fms.S[j] : 0.f);
         for (int i = 0; i < D/F16::block_size; ++i) {
             auto r = F16::load(R + F16::block_size*i);
             F16::store(qkv + F16::block_size*i, F16::mul(norm, r));
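
The change above replaces the assertion with a guarded reciprocal: when a query row is fully masked, the softmax denominator S[j] never receives a contribution and stays at zero, so the old 1/S[j] would put infinities or NaNs into the output. A minimal scalar sketch of the intended behaviour, using plain floats instead of the F16 vector wrapper; the function name is illustrative.

```cpp
// Normalize one accumulated output row by the softmax denominator S.
// A fully masked row leaves S == 0; emit zeros instead of dividing by zero.
static void normalize_row_ref(int D, float S, const float *acc, float *out) {
    const float norm = S > 0 ? 1.0f/S : 0.0f;
    for (int i = 0; i < D; ++i) out[i] = norm * acc[i];
}
```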
@@ -15658,6 +15752,7 @@ struct FlashQKfp32 {
 #ifdef HAVE_FANCY_SIMD
             if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16);
 #endif
+            if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1<1>, 1);
             MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq);
 #endif
         }
@@ -15666,6 +15761,16 @@ struct FlashQKfp32 {
             MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r8_q8_0, nq);
 #else
             MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r8_q8_1, nq);
+#endif
+        }
+        else if constexpr (std::is_same_v<KHelper, HelperQ8KVR8<D, k_step>>) {
+#ifdef __aarch64__
+            MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ80, nq);
+#else
+#ifdef HAVE_FANCY_SIMD
+            if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_r8_q8_KV<16>, 16);
+#endif
+            MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_r8_q8_KV, nq);
+#endif
 #endif
         }
         else if constexpr (std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
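
The new HelperQ8KVR8 case mirrors the HelperQ80R8 dispatch directly above it: with AVX512-VNNI available, a 16-row tile is preferred whenever the query-row count divides evenly, otherwise the per-row-count instantiations generated by MAKE_FUNCS_ONLY_NRC are used. The sketch below spells out that selection idea; the kernel-pointer alias is an assumption matching the mul_mat_* signatures in this file, and the exact set of tile sizes the macro instantiates is not shown in this hunk.

```cpp
#include <utility>

// Assumed alias matching the mul_mat_* signatures used in this file.
using mul_mat_t = void (*)(int n, const void *vx, size_t bx, const DataInfo &info, int nrc_x);

// Illustrative stand-in for the selection done above: widest tile first,
// then a generic 8-row tile, then one instantiation per small row count.
static std::pair<mul_mat_t, int> pick_q8_KV_r8_kernel(int nq) {
#ifdef HAVE_FANCY_SIMD
    if (nq % 16 == 0) return {mul_mat_q8_KV_r8_q8_KV<16>, 16};  // AVX512-VNNI path
#endif
    if (nq >= 8) return {mul_mat_q8_KV_r8_q8_KV<8>, 8};
    switch (nq) {
        case 1:  return {mul_mat_q8_KV_r8_q8_KV<1>, 1};
        case 2:  return {mul_mat_q8_KV_r8_q8_KV<2>, 2};
        case 3:  return {mul_mat_q8_KV_r8_q8_KV<3>, 3};
        case 4:  return {mul_mat_q8_KV_r8_q8_KV<4>, 4};
        case 5:  return {mul_mat_q8_KV_r8_q8_KV<5>, 5};
        case 6:  return {mul_mat_q8_KV_r8_q8_KV<6>, 6};
        default: return {mul_mat_q8_KV_r8_q8_KV<7>, 7};
    }
}
```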
@@ -15908,7 +16013,7 @@ struct FlashAttn {
     void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
             const float * q, const char * mask, float * qkv) {
         if constexpr (std::is_same_v<KHelper, HelperQ40<Dk, k_step>> || std::is_same_v<KHelper, HelperQ41<Dk, k_step>> ||
-                      std::is_same_v<KHelper, HelperIQ4nl<Dk, k_step>> || std::is_same_v<KHelper, HelperQ8KV<Dk, k_step>> ||
+                      std::is_same_v<KHelper, HelperIQ4nl<Dk, k_step>> ||
                       std::is_same_v<KHelper, HelperQ60<Dk, k_step>>) {
             compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
                     kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
@@ -15928,6 +16033,22 @@ struct FlashAttn {
                 compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
                         kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
             }
+        }
+        else if constexpr (std::is_same_v<KHelper, HelperQ8KV<Dk, k_step>>) {
+            if (nq1 >= 8) {
+#if FA_TIMING
+                auto t1 = Perf::cur_time();
+                HelperQ8KVR8<Dk, k_step> khr4(nk1, kh);
+                Perf::instance().accum(4, t1);
+#else
+                HelperQ8KVR8<Dk, k_step> khr4(nk1, kh);
+#endif
+                compute_helper_q<Dk, Dv, q_step, k_step, HelperQ8KVR8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
+                        khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
+            } else {
+                compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
+                        kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
+            }
         } else {
             compute_helper<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
                     kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
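
The new HelperQ8KV branch repacks K into the 8-row interleaved form only when at least 8 query rows are processed (prompt processing), so the one-time repack cost is amortized over many Q*K^T tiles; single-token generation keeps the flat q8_KV layout. Both the repack constructor and this branch rely on the q8_KV row layout visible throughout the patch: a float scale at dptr[0], an int32 sum of the quants at dptr + 1, and the int8 quants starting at dptr + 2. As a rough sketch, producing such a row from floats could look like the following, assuming the usual symmetric 8-bit scheme with quants in [-127, 127]; the function name and rounding details are illustrative, not taken from this patch.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Illustrative quantization of one row of n floats into the layout read above:
// [ float d | int32 sum of quants | n x int8 quants ].
static void quantize_row_q8_KV_ref(int n, const float *x, char *row) {
    float amax = 0;
    for (int j = 0; j < n; ++j) amax = std::max(amax, std::fabs(x[j]));
    const float d  = amax/127;                 // scale so quants fit in [-127, 127]
    const float id = d > 0 ? 1/d : 0.0f;
    auto dptr = (float   *)row;
    auto iptr = (int32_t *)(dptr + 1);
    auto q    = (int8_t  *)(dptr + 2);
    int32_t sum = 0;
    for (int j = 0; j < n; ++j) {
        q[j] = (int8_t)std::lround(id*x[j]);
        sum += q[j];
    }
    dptr[0] = d;   // per-row scale, picked up as dy[iy] / y[ix].d[k] in the kernels
    iptr[0] = sum; // per-row quant sum, used for the -127*sum correction on AVX512
}
```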