@@ -6171,6 +6171,80 @@ static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataIn
     }
 }
 
+// Note: HAVE_FANCY_SIMD should only be defined #if defined(__AVX512VNNI__) && defined(__AVX512VL__)
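+// Multiplies nrc_y Q8_KV rows of y with groups of 8 interleaved Q8_KV_R8 rows of x.
+// From the loads below, each 8-row group of x appears to be stored as 8 fp32 scales
+// followed by the interleaved quants, 16 per row per 128-byte block.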
+template <int nrc_y>
+static void mul_mat_q8_KV_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    GGML_ASSERT(nrc_x%8 == 0);
+    GGML_ASSERT(n%32 == 0);
+#ifndef HAVE_FANCY_SIMD
+    auto m1 = _mm256_set1_epi16(1);
+#endif
+    int nb = n / 16;
+    __m256i acc[nrc_y] = {};
+    __m256i qx[4];
+    float dy[nrc_y];
+#ifdef HAVE_FANCY_SIMD
+    float sy[nrc_y];
+#endif
+    const int8_t * q8y[nrc_y];
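+    // Each y row is assumed to begin with its fp32 scale, followed by the
+    // int32 sum of its quants (see the iptr read below), then the quants.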
+    for (int iy = 0; iy < nrc_y; ++iy) {
+        auto dptr = (const float *)info.src1_row(iy);
+        dy[iy] = dptr[0];
+#ifdef HAVE_FANCY_SIMD
+        auto iptr = (const int32_t *)(dptr + 1);
+        sy[iy] = -128*iptr[0];
+#endif
+        q8y[iy] = (const int8_t *)(dptr + 2);
+    }
+    for (int ix = 0; ix < nrc_x; ix += 8) {
+        auto dptr = (const float *)((const char *)vx + ix*bx);
+        auto dx = _mm256_loadu_ps(dptr);
+        auto q8x = (const int8_t *)(dptr + 8);
+        for (int ib = 0; ib < nb; ++ib) { // blocks of 16 columns, loaded as 4 x 32 bytes of interleaved quants
+            qx[0] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+0);
+            qx[1] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+1);
+            qx[2] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+2);
+            qx[3] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+3);
+#ifndef HAVE_FANCY_SIMD
+            auto s0 = _mm256_sign_epi8(qx[0], qx[0]);
+            auto s1 = _mm256_sign_epi8(qx[1], qx[1]);
+            auto s2 = _mm256_sign_epi8(qx[2], qx[2]);
+            auto s3 = _mm256_sign_epi8(qx[3], qx[3]);
+#endif
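+            // Without VNNI, _mm256_maddubs_epi16 needs an unsigned first operand,
+            // so |x| is used above and the sign of x is transferred onto y below,
+            // which leaves the products x[i]*y[i] unchanged.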
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto y128 = _mm_loadu_si128((const __m128i*)q8y[iy]+ib);
+                auto y = MM256_SET_M128I(y128, y128);
+#ifdef HAVE_FANCY_SIMD
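+                // dpbusd treats qx as unsigned; the repacked x quants are presumably
+                // stored with a +128 offset, which the sy = -128*sum(y) term added
+                // before the store corrects for.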
+                acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[0], _mm256_shuffle_epi32(y, 0x00));
+                acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[1], _mm256_shuffle_epi32(y, 0x55));
+                acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[2], _mm256_shuffle_epi32(y, 0xaa));
+                acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[3], _mm256_shuffle_epi32(y, 0xff));
+#else
+                auto sumi1 = _mm256_maddubs_epi16(s0, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
+                auto sumi2 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
+                auto sumi3 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
+                auto sumi4 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
+                auto sumi12 = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi1), _mm256_madd_epi16(m1, sumi2));
+                auto sumi34 = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi3), _mm256_madd_epi16(m1, sumi4));
+                acc[iy] = _mm256_add_epi32(acc[iy], _mm256_add_epi32(sumi12, sumi34));
+#endif
+            }
+        }
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            auto scale = _mm256_mul_ps(dx, _mm256_set1_ps(dy[iy]));
+#ifdef HAVE_FANCY_SIMD
+            acc[iy] = _mm256_add_epi32(acc[iy], _mm256_set1_epi32(sy[iy]));
+#endif
+            info.store(ix, iy, _mm256_mul_ps(scale, _mm256_cvtepi32_ps(acc[iy])));
+            acc[iy] = _mm256_setzero_si256();
+        }
+    }
+}
+
 template <int nrc_y>
 static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     GGML_ASSERT(nrc_x%8 == 0);
@@ -14516,13 +14590,50 @@ struct HelperF16 final : public BaseHelper<step> {
     }
 };
 
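+// A whole row quantized as a single block: one fp32 scale d, the int32 sum s of
+// the quants (apparently kept for the AVX512/VNNI bias correction), then the quants.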
+template <int D> struct block_q8_KV {
+    float d;
+    int s;
+    int8_t qs[D];
+};
+
+template <int D, int step>
+struct HelperQ8KV final : public BaseHelper<step> {
+    using Base = BaseHelper<step>;
+    using block_q8 = block_q8_KV<D>;
+    constexpr static int block_size_q = D;
+    HelperQ8KV(const char * data, int stride) : Base(data, stride) {}
+
+    // Needed for v * softmax(k * q)
+    inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
+        auto q8 = (const block_q8_KV<D> *)Base::lblock(l1);
+#ifdef __aarch64__
+        auto vd = F16::set1(q8->d);
+        // 16 quants starting at lane 8*i, assuming i indexes 8-value lanes as in the AVX2 branch below
+        auto qs = vld1_s8_x2(q8->qs + 8*i);
+        v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[0])));
+        v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[1])));
+#else
+        auto vd = F16::set1(q8->d);
+#ifdef HAVE_FANCY_SIMD
+        v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)q8->qs+i+0))));
+        v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)q8->qs+i+1))));
+#else
+        v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(q8->qs+8*i+0)))));
+        v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(q8->qs+8*i+8)))));
+#endif
+#endif
+    }
+};
+
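+// Not part of this commit: an illustrative sketch of what quantize_row_q8_KV
+// (used by HelperQ80::convert below) is assumed to produce for one row:
+//
+//   template <int D> void quantize_row_q8_KV_ref(const float * x, block_q8_KV<D> * y) {
+//       float amax = 0;
+//       for (int j = 0; j < D; ++j) amax = std::max(amax, std::abs(x[j]));
+//       float d = amax/127, id = d > 0 ? 1/d : 0;
+//       int sum = 0;
+//       for (int j = 0; j < D; ++j) {
+//           y->qs[j] = (int8_t)nearest_int(id*x[j]);
+//           sum += y->qs[j];
+//       }
+//       y->d = d;   // per-row scale
+//       y->s = sum; // sum of quants, used for the dpbusd bias correction
+//   }
+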
 template <int D, int step>
 struct HelperQ80 final : public BaseHelper<step> {
     using Base = BaseHelper<step>;
 #ifdef HAVE_FANCY_SIMD
     using block_q8 = block_q8_1;
+    constexpr static int block_size_q = QK8_1;
 #else
     using block_q8 = block_q8_0;
+    constexpr static int block_size_q = QK8_0;
 #endif
     HelperQ80(const char * data, int stride) : Base(data, stride) {}
 
@@ -14566,23 +14677,33 @@ struct HelperQ80 final : public BaseHelper<step> {
             y += D/QK8_1;
         }
     }
+
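+    // Q8_KV overload: each row of D floats becomes a single block_q8_KV<D>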
+    static inline void convert(int nq, int stride_q, const float * q, block_q8_KV<D> * y) {
+        for (int i = 0; i < nq; ++i) {
+            quantize_row_q8_KV(q, y, D);
+            q += stride_q;
+            ++y;
+        }
+    }
 };
 
 template <int D, int step>
-struct HelperQ80R4 : public BaseHelper<step> {
+struct HelperQ80R8 : public BaseHelper<step> {
     using Base = BaseHelper<step>;
 #ifdef __AVX2__
+    constexpr static int block_size_q = QK8_1;
     using block_q8 = block_q8_1;
 #else
+    constexpr static int block_size_q = QK8_0;
     using block_q8 = block_q8_0;
 #endif
-    HelperQ80R4(int nk, const HelperQ80<D, step>& q8) : Base(q8.data, q8.stride) {
+    HelperQ80R8(int nk, const HelperQ80<D, step>& q8) : Base(q8.data, q8.stride) {
         r4 = repack(nk, q8);
         Base::data = (const char *)r4.data();
         Base::stride = (D/QK8_0)*sizeof(block_q8_0);
     }
 
-    static std::vector<block_q8_0_r8> repack(int nk, const HelperQ80<D, step> q8) {
+    static std::vector<block_q8_0_r8> repack(int nk, const HelperQ80<D, step>& q8) {
         static_assert(D%QK8_0 == 0);
         GGML_ASSERT(nk%8 == 0);
         constexpr int nblock = D/QK8_0;
@@ -14685,6 +14806,7 @@ template <int D, int step>
 struct HelperQ40 final : public BaseHelper<step> {
     using Base = BaseHelper<step>;
     using block_q8 = block_q8_0;
+    constexpr static int block_size_q = QK8_0;
     HelperQ40(const char * data, int stride) : Base(data, stride) {}
 
     // Needed for v * softmax(k * q)
@@ -14728,6 +14850,7 @@ template <int D, int step>
 struct HelperQ41 final : public BaseHelper<step> {
     using Base = BaseHelper<step>;
     using block_q8 = block_q8_1;
+    constexpr static int block_size_q = QK8_1;
     HelperQ41(const char * data, int stride) : Base(data, stride) {}
 
     // Needed for v * softmax(k * q)
@@ -14818,8 +14941,10 @@ template <int D, int step>
 struct HelperQ60 final : public BaseHelper<step> {
 #ifdef __aarch64__
     using block_q8 = block_q8_0;
+    constexpr static int block_size_q = QK8_0;
 #else
     using block_q8 = block_q8_1;
+    constexpr static int block_size_q = QK8_1;
 #endif
     using Base = BaseHelper<step>;
     HelperQ60(const char * data, int stride) : Base(data, stride) {}
@@ -15526,7 +15651,17 @@ struct FlashQKfp32 {
 #endif
 #endif
         }
-        else if constexpr (std::is_same_v<KHelper, HelperQ80R4<D, k_step>>) {
+        else if constexpr (std::is_same_v<KHelper, HelperQ8KV<D, k_step>>) {
+#ifdef __aarch64__
+            MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ80, nq);
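+            // (the unbalanced '<' above is apparently closed inside the MAKE_FUNCS
+            // macro, which appends the row-count template argument)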
+#else
+#ifdef HAVE_FANCY_SIMD
+            if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16);
+#endif
+            MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq);
+#endif
+        }
+        else if constexpr (std::is_same_v<KHelper, HelperQ80R8<D, k_step>>) {
 #ifdef __aarch64__
         MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r8_q8_0, nq);
 #else
@@ -15575,7 +15710,7 @@ struct FlashQKfp32 {
         constexpr int kMaxQ = 8;
         static_assert(q_step < kMaxQ || q_step%kMaxQ == 0);
         auto [mul_mat, nrc_q] = mul_mat_kernel<KHelper>(q_step);
-        DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr};
+        DataInfo info{fms.cache, (const char *)q, k_step, (D/KHelper::block_size_q)*sizeof(block_q8), 0, 1, nullptr};
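+        // Q8_KV packs a whole row into a single block, so the per-row byte size
+        // must come from the helper's block size rather than the fixed QK8_0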
         for (int iq = 0; iq < q_step/nrc_q; ++iq) {
             mul_mat(D, kh.block, kh.stride, info, k_step);
             info.cur_y += nrc_q;
@@ -15597,7 +15732,7 @@ struct FlashQKfp32 {
     static inline void mul_mask_kq(int nq, const KHelper& kh, int stride_m,
             const block_q8 * q, const char * mask, FlashMS<q_step, k_step>& fms) {
         auto [mul_mat, nrc_q] = mul_mat_kernel<KHelper>(nq);
-        DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr};
+        DataInfo info{fms.cache, (const char *)q, k_step, (D/KHelper::block_size_q)*sizeof(block_q8), 0, 1, nullptr};
         for (int iq = 0; iq < nq/nrc_q; ++iq) {
             mul_mat(D, kh.block, kh.stride, info, k_step);
             info.cur_y += nrc_q;
@@ -15685,7 +15820,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
                       FlashMS<q_step, k_step>& fms,
                       FlashQKV<Dv, q_step, k_step>& fqkv,
                       const float * q, const char * mask, float * qkv) {
-    typename KHelper::block_q8 q8[q_step*(Dk/QK8_0)];
+    typename KHelper::block_q8 q8[q_step*(Dk/KHelper::block_size_q)];
 #if FA_TIMING
     Perf perf(false);
 #endif
@@ -15773,7 +15908,7 @@ struct FlashAttn {
     void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
             const float * q, const char * mask, float * qkv) {
         if constexpr (std::is_same_v<KHelper, HelperQ40<Dk, k_step>> || std::is_same_v<KHelper, HelperQ41<Dk, k_step>> ||
-                      std::is_same_v<KHelper, HelperIQ4nl<Dk, k_step>> ||
+                      std::is_same_v<KHelper, HelperIQ4nl<Dk, k_step>> || std::is_same_v<KHelper, HelperQ8KV<Dk, k_step>> ||
                       std::is_same_v<KHelper, HelperQ60<Dk, k_step>>) {
             compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
                     kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
@@ -15782,12 +15917,12 @@ struct FlashAttn {
         if (nq1 >= 8) {
 #if FA_TIMING
             auto t1 = Perf::cur_time();
-            HelperQ80R4<Dk, k_step> khr4(nk1, kh);
+            HelperQ80R8<Dk, k_step> khr4(nk1, kh);
             Perf::instance().accum(4, t1);
 #else
-            HelperQ80R4<Dk, k_step> khr4(nk1, kh);
+            HelperQ80R8<Dk, k_step> khr4(nk1, kh);
 #endif
-            compute_helper_q<Dk, Dv, q_step, k_step, HelperQ80R4<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
+            compute_helper_q<Dk, Dv, q_step, k_step, HelperQ80R8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
                     khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
         } else {
             compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
@@ -16311,6 +16446,10 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
             HelperQ80<Dv, k_step> vh(v, stride_v);
             iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
         } break;
+        case GGML_TYPE_Q8_KV: {
+            HelperQ8KV<Dv, k_step> vh(v, stride_v);
+            iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+        } break;
         case GGML_TYPE_Q6_0: {
             HelperQ60<Dv, k_step> vh(v, stride_v);
             iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
@@ -16348,6 +16487,10 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
             HelperQ80<Dk, k_step> kh(k, stride_k);
             iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
         } break;
+        case GGML_TYPE_Q8_KV: {
+            HelperQ8KV<Dk, k_step> kh(k, stride_k);
+            iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+        } break;
         case GGML_TYPE_Q6_0: {
             HelperQ60<Dk, k_step> kh(k, stride_k);
             iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
@@ -16379,7 +16522,7 @@ inline bool flash_attn_is_supported(ggml_type type) {
     if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 ||
         type == GGML_TYPE_Q6_0 || type == GGML_TYPE_IQ4_NL) return true;
 #else
-    if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q6_0) return true;
+    if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q6_0 || type == GGML_TYPE_Q8_KV) return true;
 #endif
     return false;
 }