
Commit 8f004a0

Author: Iwan Kawrakow (committed)
q8_KV: be able to use it for K cache in FA
1 parent 0280b8d commit 8f004a0

File tree: 1 file changed, +155 -12 lines

ggml/src/iqk/iqk_mul_mat.cpp

Lines changed: 155 additions & 12 deletions
@@ -6171,6 +6171,80 @@ static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataIn
     }
 }
 
+// The HAVE_FANCY_SIMD should only be #if defined(__AVX512_VNNI__ && defined(__AVX512VL__)
+template <int nrc_y>
+static void mul_mat_q8_KV_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    GGML_ASSERT(nrc_x%8 == 0);
+    GGML_ASSERT(n%32 == 0);
+#ifndef HAVE_FANCY_SIMD
+    auto m1 = _mm256_set1_epi16(1);
+#endif
+    int nb = n / 16;
+    __m256i acc[nrc_y] = {};
+    __m256i qx[4];
+    float dy[nrc_y];
+#ifdef HAVE_FANCY_SIMD
+    float sy[nrc_y];
+#endif
+    const int8_t * q8y[nrc_y];
+    for (int iy = 0; iy < nrc_y; ++iy) {
+        auto dptr = (const float *)info.src1_row(iy);
+        dy[iy] = dptr[0];
+#ifdef HAVE_FANCY_SIMD
+        auto iptr = (const int32_t *)(dptr + 1);
+        sy[iy] = -128*iptr[0];
+#endif
+        q8y[iy] = (const int8_t *)(dptr + 2);
+    }
+    for (int ix = 0; ix < nrc_x; ix += 8) {
+        auto dptr = (const float *)((const char *)vx + ix*bx);
+        auto dx = _mm256_loadu_ps(dptr);
+        auto q8x = (const int8_t *)(dptr + 8);
+        for (int ib = 0; ib < nb; ++ib) { // Blocks of 32
+            qx[0] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+0);
+            qx[1] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+1);
+            qx[2] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+2);
+            qx[3] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+3);
+#ifndef HAVE_FANCY_SIMD
+            auto s0 = _mm256_sign_epi8(qx[0], qx[0]);
+            auto s1 = _mm256_sign_epi8(qx[1], qx[1]);
+            auto s2 = _mm256_sign_epi8(qx[2], qx[2]);
+            auto s3 = _mm256_sign_epi8(qx[3], qx[3]);
+#endif
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto y128 = _mm_loadu_si128((const __m128i*)q8y[iy]+ib);
+                auto y = MM256_SET_M128I(y128, y128);
+#ifdef HAVE_FANCY_SIMD
+                acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[0], _mm256_shuffle_epi32(y, 0x00));
+                acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[1], _mm256_shuffle_epi32(y, 0x55));
+                acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[2], _mm256_shuffle_epi32(y, 0xaa));
+                acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[3], _mm256_shuffle_epi32(y, 0xff));
+#else
+                auto sumi1 = _mm256_maddubs_epi16(s0, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
+                auto sumi2 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
+                auto sumi3 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
+                auto sumi4 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
+                auto sumi12 = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi1), _mm256_madd_epi16(m1, sumi2));
+                auto sumi34 = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi3), _mm256_madd_epi16(m1, sumi4));
+                acc[iy] = _mm256_add_epi32(acc[iy], _mm256_add_epi32(sumi12, sumi34));
+#endif
+            }
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto scale = _mm256_mul_ps(dx, _mm256_set1_ps(dy[iy]));
+#ifdef HAVE_FANCY_SIMD
+                acc[iy] = _mm256_add_epi32(acc[iy], _mm256_set1_epi32(sy[iy]));
+#endif
+                info.store(ix, iy, _mm256_mul_ps(scale, _mm256_cvtepi32_ps(acc[iy])));
+                acc[iy] = _mm256_setzero_si256();
+            }
+        }
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            info.store(ix, iy, acc[iy]);
+            acc[iy] = _mm256_setzero_ps();
+        }
+    }
+}
+
 template <int nrc_y>
 static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     GGML_ASSERT(nrc_x%8 == 0);
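The VNNI path above depends on an offset trick: _mm256_dpbusd_epi32 multiplies unsigned by signed bytes, so the repacked K quants have to be stored with a +128 bias (that repacking is not part of this diff), and the bias is undone through the precomputed per-row integer stored right after the Q-row scale (sy[iy] = -128*iptr[0]), which for the correction to be exact must be the sum of that row's quants. A minimal scalar sketch of the identity, under exactly those assumptions:

// Scalar model of one output element of the VNNI path above (my sketch, not repo code).
// Assumptions: x_biased[j] = x[j] + 128 (biased K quants), y_sum = sum of y[j]
// stored next to the row scale, dx/dy are the two per-row scales.
#include <cstdint>

static float dot_q8_KV_scalar(const uint8_t * x_biased, const int8_t * y,
                              int32_t y_sum, float dx, float dy, int n) {
    int32_t acc = 0;
    for (int j = 0; j < n; ++j) acc += int32_t(x_biased[j]) * int32_t(y[j]); // what dpbusd accumulates
    acc -= 128 * y_sum;      // sum((x+128)*y) - 128*sum(y) == sum(x*y)
    return dx * dy * acc;    // apply the two per-row scales, as in the store loop
}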
@@ -14516,13 +14590,50 @@ struct HelperF16 final : public BaseHelper<step> {
     }
 };
 
+template <int D> struct block_q8_KV {
+    float d;
+    int s;
+    int8_t qs[D];
+};
+
+template <int D, int step>
+struct HelperQ8KV final : public BaseHelper<step> {
+    using Base = BaseHelper<step>;
+    using block_q8 = block_q8_KV<D>;
+    constexpr static int block_size_q = D;
+    HelperQ8KV(const char * data, int stride) : Base(data, stride) {}
+
+    // Needed for v * softmax(k * q)
+    inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
+        auto q8 = (const block_q8_KV<D> *)Base::lblock(l1);
+#ifdef __aarch64__
+        auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
+        int ii = j%QK8_0;
+        auto qs = vld1_s8_x2(dl->qs + ii);
+        v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[0])));
+        v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[1])));
+#else
+        auto vd = F16::set1(q8->d);
+#ifdef HAVE_FANCY_SIMD
+        v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)q8->qs+i+0))));
+        v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)q8->qs+i+1))));
+#else
+        v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(q8->qs+8*i+0)))));
+        v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(q8->qs+8*i+8)))));
+#endif
+#endif
+    }
+};
+
 template <int D, int step>
 struct HelperQ80 final : public BaseHelper<step> {
     using Base = BaseHelper<step>;
 #ifdef HAVE_FANCY_SIMD
     using block_q8 = block_q8_1;
+    constexpr static int block_size_q = QK8_1;
 #else
     using block_q8 = block_q8_0;
+    constexpr static int block_size_q = QK8_0;
 #endif
     HelperQ80(const char * data, int stride) : Base(data, stride) {}
 
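Unlike block_q8_0, which carries an fp16 scale per 32 quants, block_q8_KV keeps a single float scale d and a single quant sum s for all D values of a row, followed by the D int8 quants; HelperQ8KV::load above converts a handful of those quants back to floats per call (the vector width differs per ISA). A scalar sketch of the same dequantization, restating the struct from the diff so the snippet stands alone (the scalar form is mine):

#include <cstdint>

template <int D> struct block_q8_KV {   // layout as added in the diff above
    float  d;                            // one scale for the whole row
    int    s;                            // sum of the D quants (used by the VNNI kernels)
    int8_t qs[D];                        // the quants themselves
};

// Scalar equivalent of what HelperQ8KV::load produces: 16 consecutive values of row x as floats.
template <int D>
static void dequant_q8_KV_16(const block_q8_KV<D>& x, int j, float * out) {
    for (int k = 0; k < 16; ++k) out[k] = x.d * x.qs[j + k];
}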
@@ -14566,23 +14677,33 @@ struct HelperQ80 final : public BaseHelper<step> {
             y += D/QK8_1;
         }
     }
+
+    static inline void convert(int nq, int stride_q, const float * q, block_q8_KV<D> * y) {
+        for (int i = 0; i < nq; ++i) {
+            quantize_row_q8_KV(q, y, D);
+            q += stride_q;
+            ++y;
+        }
+    }
 };
 
 template <int D, int step>
-struct HelperQ80R4 : public BaseHelper<step> {
+struct HelperQ80R8 : public BaseHelper<step> {
     using Base = BaseHelper<step>;
 #ifdef __AVX2__
+    constexpr static int block_size_q = QK8_1;
     using block_q8 = block_q8_1;
 #else
+    constexpr static int block_size_q = QK8_0;
     using block_q8 = block_q8_0;
 #endif
-    HelperQ80R4(int nk, const HelperQ80<D, step>& q8) : Base(q8.data, q8.stride) {
+    HelperQ80R8(int nk, const HelperQ80<D, step>& q8) : Base(q8.data, q8.stride) {
        r4 = repack(nk, q8);
        Base::data = (const char *)r4.data();
        Base::stride = (D/QK8_0)*sizeof(block_q8_0);
    }
 
-    static std::vector<block_q8_0_r8> repack(int nk, const HelperQ80<D, step> q8) {
+    static std::vector<block_q8_0_r8> repack(int nk, const HelperQ80<D, step>& q8) {
        static_assert(D%QK8_0 == 0);
        GGML_ASSERT(nk%8 == 0);
        constexpr int nblock = D/QK8_0;
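The new HelperQ80::convert overload quantizes nq rows of Q activations into block_q8_KV<D> via quantize_row_q8_KV, whose definition lives elsewhere in the tree. As a hedged sketch of what such a routine has to produce for the kernels above, namely an absmax scale plus the quant sum (the real implementation may round, clamp, and vectorize differently):

#include <algorithm>
#include <cmath>
#include <cstdint>

template <int D> struct block_q8_KV { float d; int s; int8_t qs[D]; };  // layout from the diff above

// Hypothetical scalar stand-in for quantize_row_q8_KV (not the repo's code):
// symmetric absmax quantization of D floats to int8, recording scale and quant sum.
template <int D>
static void quantize_row_q8_KV_sketch(const float * x, block_q8_KV<D> * y) {
    float amax = 0.f;
    for (int j = 0; j < D; ++j) amax = std::max(amax, std::fabs(x[j]));
    const float d   = amax / 127.f;
    const float inv = d > 0 ? 1.f / d : 0.f;
    int sum = 0;
    for (int j = 0; j < D; ++j) {
        int q = (int)std::lroundf(x[j] * inv);  // |q| <= 127 because amax maps to 127
        y->qs[j] = (int8_t)q;
        sum += q;
    }
    y->d = d;
    y->s = sum;
}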
@@ -14685,6 +14806,7 @@ template <int D, int step>
 struct HelperQ40 final : public BaseHelper<step> {
     using Base = BaseHelper<step>;
     using block_q8 = block_q8_0;
+    constexpr static int block_size_q = QK8_0;
     HelperQ40(const char * data, int stride) : Base(data, stride) {}
 
     // Needed for v * softmax(k * q)
@@ -14728,6 +14850,7 @@ template <int D, int step>
 struct HelperQ41 final : public BaseHelper<step> {
     using Base = BaseHelper<step>;
     using block_q8 = block_q8_1;
+    constexpr static int block_size_q = QK8_1;
     HelperQ41(const char * data, int stride) : Base(data, stride) {}
 
     // Needed for v * softmax(k * q)
@@ -14818,8 +14941,10 @@ template <int D, int step>
 struct HelperQ60 final : public BaseHelper<step> {
 #ifdef __aarch64__
     using block_q8 = block_q8_0;
+    constexpr static int block_size_q = QK8_0;
 #else
     using block_q8 = block_q8_1;
+    constexpr static int block_size_q = QK8_1;
 #endif
     using Base = BaseHelper<step>;
     HelperQ60(const char * data, int stride) : Base(data, stride) {}
@@ -15526,7 +15651,17 @@ struct FlashQKfp32 {
 #endif
 #endif
         }
-        else if constexpr (std::is_same_v<KHelper, HelperQ80R4<D, k_step>>) {
+        else if constexpr (std::is_same_v<KHelper, HelperQ8KV<D, k_step>>) {
+#ifdef __aarch64__
+            MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ80, nq);
+#else
+#ifdef HAVE_FANCY_SIMD
+            if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16);
+#endif
+            MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq);
+#endif
+        }
+        else if constexpr (std::is_same_v<KHelper, HelperQ80R8<D, k_step>>) {
 #ifdef __aarch64__
             MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r8_q8_0, nq);
 #else
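mul_mat_kernel now recognizes HelperQ8KV and returns a (kernel, rows-per-call) pair, taking the 16-row instantiation when AVX512 is available and nq divides evenly, and otherwise letting MAKE_FUNCS_ONLY_NRC pick a smaller width; the macros themselves are defined earlier in this file and are not shown in the diff. An illustrative, self-contained sketch of that dispatch pattern (generic names, not the repo's macro expansion):

#include <cstddef>
#include <utility>

struct DataInfo;  // opaque here; the real struct lives in iqk_mul_mat.cpp
using mul_mat_t = void (*)(int n, const void * vx, size_t bx, const DataInfo & info, int nrc_x);

template <int nrc_y>
static void q8_KV_kernel(int, const void *, size_t, const DataInfo &, int) { /* ... */ }

// Pick the widest row count that divides nq, mirroring the nq%16 special case above.
static std::pair<mul_mat_t, int> pick_q8_KV_kernel(int nq) {
    if (nq % 16 == 0) return {q8_KV_kernel<16>, 16};
    if (nq %  8 == 0) return {q8_KV_kernel<8>,   8};
    if (nq %  4 == 0) return {q8_KV_kernel<4>,   4};
    if (nq %  2 == 0) return {q8_KV_kernel<2>,   2};
    return {q8_KV_kernel<1>, 1};
}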
@@ -15575,7 +15710,7 @@ struct FlashQKfp32 {
         constexpr int kMaxQ = 8;
         static_assert(q_step < kMaxQ || q_step%kMaxQ == 0);
         auto [mul_mat, nrc_q] = mul_mat_kernel<KHelper>(q_step);
-        DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr};
+        DataInfo info{fms.cache, (const char *)q, k_step, (D/KHelper::block_size_q)*sizeof(block_q8), 0, 1, nullptr};
         for (int iq = 0; iq < q_step/nrc_q; ++iq) {
             mul_mat(D, kh.block, kh.stride, info, k_step);
             info.cur_y += nrc_q;
@@ -15597,7 +15732,7 @@ struct FlashQKfp32 {
     static inline void mul_mask_kq(int nq, const KHelper& kh, int stride_m,
             const block_q8 * q, const char * mask, FlashMS<q_step, k_step>& fms) {
         auto [mul_mat, nrc_q] = mul_mat_kernel<KHelper>(nq);
-        DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr};
+        DataInfo info{fms.cache, (const char *)q, k_step, (D/KHelper::block_size_q)*sizeof(block_q8), 0, 1, nullptr};
         for (int iq = 0; iq < nq/nrc_q; ++iq) {
             mul_mat(D, kh.block, kh.stride, info, k_step);
             info.cur_y += nrc_q;
@@ -15685,7 +15820,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
                       FlashMS<q_step, k_step>& fms,
                       FlashQKV<Dv, q_step, k_step>& fqkv,
                       const float * q, const char * mask, float * qkv) {
-    typename KHelper::block_q8 q8[q_step*(Dk/QK8_0)];
+    typename KHelper::block_q8 q8[q_step*(Dk/KHelper::block_size_q)];
 #if FA_TIMING
     Perf perf(false);
 #endif
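The three D/KHelper::block_size_q changes (the two DataInfo strides above and the q8[] scratch buffer here) exist because block_q8 is no longer always a 32-quant block: with HelperQ8KV it is a single block_q8_KV<D> per row, so the per-row byte stride has to come from the helper's own block size. A quick size check for D = 128, with the struct layouts restated from ggml and the diff (my arithmetic, not part of the commit):

#include <cstdint>

struct block_q8_0 { uint16_t d; int8_t qs[32]; };                        // QK8_0 = 32, fp16 scale
template <int D> struct block_q8_KV { float d; int s; int8_t qs[D]; };   // one block per row

static_assert((128 / 32)  * sizeof(block_q8_0)       == 136, "q8_0: 4 blocks x 34 bytes per row");
static_assert((128 / 128) * sizeof(block_q8_KV<128>) == 136, "q8_KV: a single 136-byte block");
static_assert((128 / 32)  * sizeof(block_q8_KV<128>) == 544, "old D/QK8_0 formula would be 4x too big");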
@@ -15773,7 +15908,7 @@ struct FlashAttn {
     void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
             const float * q, const char * mask, float * qkv) {
         if constexpr (std::is_same_v<KHelper, HelperQ40<Dk, k_step>> || std::is_same_v<KHelper, HelperQ41<Dk, k_step>> ||
-                      std::is_same_v<KHelper, HelperIQ4nl<Dk, k_step>> ||
+                      std::is_same_v<KHelper, HelperIQ4nl<Dk, k_step>> || std::is_same_v<KHelper, HelperQ8KV<Dk, k_step>> ||
                       std::is_same_v<KHelper, HelperQ60<Dk, k_step>>) {
             compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
                     kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
@@ -15782,12 +15917,12 @@ struct FlashAttn {
             if (nq1 >= 8) {
 #if FA_TIMING
                 auto t1 = Perf::cur_time();
-                HelperQ80R4<Dk, k_step> khr4(nk1, kh);
+                HelperQ80R8<Dk, k_step> khr4(nk1, kh);
                 Perf::instance().accum(4, t1);
 #else
-                HelperQ80R4<Dk, k_step> khr4(nk1, kh);
+                HelperQ80R8<Dk, k_step> khr4(nk1, kh);
 #endif
-                compute_helper_q<Dk, Dv, q_step, k_step, HelperQ80R4<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
+                compute_helper_q<Dk, Dv, q_step, k_step, HelperQ80R8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
                         khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
             } else{
                 compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
@@ -16311,6 +16446,10 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
             HelperQ80<Dv, k_step> vh(v, stride_v);
             iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
         } break;
+        case GGML_TYPE_Q8_KV: {
+            HelperQ8KV<Dv, k_step> vh(v, stride_v);
+            iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+        } break;
         case GGML_TYPE_Q6_0: {
             HelperQ60<Dv, k_step> vh(v, stride_v);
             iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
@@ -16348,6 +16487,10 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
             HelperQ80<Dk, k_step> kh(k, stride_k);
             iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
         } break;
+        case GGML_TYPE_Q8_KV: {
+            HelperQ8KV<Dk, k_step> kh(k, stride_k);
+            iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+        } break;
         case GGML_TYPE_Q6_0: {
             HelperQ60<Dk, k_step> kh(k, stride_k);
             iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
@@ -16379,7 +16522,7 @@ inline bool flash_attn_is_supported(ggml_type type) {
     if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 ||
         type == GGML_TYPE_Q6_0 || type == GGML_TYPE_IQ4_NL) return true;
 #else
-    if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q6_0) return true;
+    if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q6_0 || type == GGML_TYPE_Q8_KV) return true;
 #endif
     return false;
 }
