Skip to content

Commit 0526db2

Browse files
author
Iwan Kawrakow
committed
q8_KV_r8 - repacked q8_KV
On Zen4 it is slower than q8_k_r8 (292 vs 370 t/s). This makes no sense whatsoever, as the q8_KV_r8 GEMM is basically the q8_k_r8 GEMM with the unnecessary block handling removed (so one would expect it to be faster).
1 parent 1d57041 commit 0526db2

File tree

9 files changed

+215
-6
lines changed

9 files changed

+215
-6
lines changed

examples/quantize/quantize.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
8383
{ "Q6_K", LLAMA_FTYPE_MOSTLY_Q6_K, " 5.15G, +0.0008 ppl @ LLaMA-v1-7B", },
8484
{ "Q6_K_R4", LLAMA_FTYPE_MOSTLY_Q6_K_R4, "Q6_K repacked", },
8585
{ "Q8_K_R8", LLAMA_FTYPE_MOSTLY_Q8_K_R8, "Q8_K repacked", },
86+
{ "Q8_KV_R8", LLAMA_FTYPE_MOSTLY_Q8_KV_R8, "Q8_KV repacked", },
8687
{ "Q8_0", LLAMA_FTYPE_MOSTLY_Q8_0, " 6.70G, +0.0004 ppl @ LLaMA-v1-7B", },
8788
{ "Q4_0_4_4", LLAMA_FTYPE_MOSTLY_Q4_0_4_4, " 4.34G, +0.4685 ppl @ Llama-3-8B", },
8889
{ "Q4_0_4_8", LLAMA_FTYPE_MOSTLY_Q4_0_4_8, " 4.34G, +0.4685 ppl @ Llama-3-8B", },

ggml/include/ggml.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,7 @@ extern "C" {
443443
GGML_TYPE_IQ4_K_R4 = 339,
444444
GGML_TYPE_IQ5_K_R4 = 340,
445445
GGML_TYPE_IQ4_KS_R4 = 344,
446+
GGML_TYPE_Q8_KV_R8 = 398,
446447
GGML_TYPE_Q8_K_R8 = 399,
447448
GGML_TYPE_COUNT,
448449
};
@@ -529,6 +530,7 @@ extern "C" {
529530
GGML_FTYPE_MOSTLY_IQ4_K_R4 = 332, // except 1d tensors
530531
GGML_FTYPE_MOSTLY_IQ5_K_R4 = 333, // except 1d tensors
531532
GGML_FTYPE_MOSTLY_IQ4_KS_R4 = 337, // except 1d tensors
533+
GGML_FTYPE_MOSTLY_Q8_KV_R8 = 398, // except 1d tensors
532534
GGML_FTYPE_MOSTLY_Q8_K_R8 = 399, // except 1d tensors
533535
};
534536

ggml/src/ggml-quants.c

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15214,9 +15214,10 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
1521415214
case GGML_TYPE_IQ3_K_R4: break;
1521515215
case GGML_TYPE_IQ4_K_R4: break;
1521615216
case GGML_TYPE_IQ5_K_R4: break;
15217-
case GGML_TYPE_IQ4_KS_R4: break;
15218-
case GGML_TYPE_Q8_K_R8: break;
15219-
case GGML_TYPE_Q8_KV: break;
15217+
case GGML_TYPE_IQ4_KS_R4:break;
15218+
case GGML_TYPE_Q8_KV_R8: break;
15219+
case GGML_TYPE_Q8_K_R8: break;
15220+
case GGML_TYPE_Q8_KV: break;
1522015221
case GGML_TYPE_BF16_R16: break;
1522115222
case GGML_TYPE_Q4_0_4_4:
1522215223
case GGML_TYPE_Q4_0_4_8:

ggml/src/ggml.c

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1374,6 +1374,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
13741374
.vec_dot_type = GGML_TYPE_Q8_KV,
13751375
.row_meta_size = 8,
13761376
},
1377+
[GGML_TYPE_Q8_KV_R8] = {
1378+
.type_name = "q8_KV_r8",
1379+
.blck_size = 32,
1380+
.type_size = 32,
1381+
.is_quantized = true,
1382+
.to_float = (ggml_to_float_t) dequantize_row_q8_KV_r8,
1383+
.from_float = quantize_row_q8_KV_r8,
1384+
.from_float_ref = (ggml_from_float_t)quantize_row_q8_KV_r8_ref,
1385+
.vec_dot = vec_dot_q8_KV_r8_q8_KV,
1386+
.vec_dot_type = GGML_TYPE_Q8_KV,
1387+
.row_meta_size = 4,
1388+
},
13771389
[GGML_TYPE_Q8_K16] = {
13781390
.type_name = "q8_K16",
13791391
.blck_size = 64,
@@ -4397,6 +4409,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
43974409
case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break;
43984410
case GGML_FTYPE_MOSTLY_Q6_K_R4: wtype = GGML_TYPE_Q6_K_R4; break;
43994411
case GGML_FTYPE_MOSTLY_Q8_K_R8: wtype = GGML_TYPE_Q8_K_R8; break;
4412+
case GGML_FTYPE_MOSTLY_Q8_KV_R8: wtype = GGML_TYPE_Q8_KV_R8; break;
44004413
case GGML_FTYPE_MOSTLY_IQ2_XXS: wtype = GGML_TYPE_IQ2_XXS; break;
44014414
case GGML_FTYPE_MOSTLY_IQ2_XXS_R4: wtype = GGML_TYPE_IQ2_XXS_R4;break;
44024415
case GGML_FTYPE_MOSTLY_IQ2_XS: wtype = GGML_TYPE_IQ2_XS; break;
@@ -10949,6 +10962,7 @@ static void ggml_compute_forward_add(
1094910962
case GGML_TYPE_Q6_K:
1095010963
case GGML_TYPE_Q6_K_R4:
1095110964
case GGML_TYPE_Q8_K_R8:
10965+
case GGML_TYPE_Q8_KV_R8:
1095210966
case GGML_TYPE_IQ2_XXS:
1095310967
case GGML_TYPE_IQ2_XXS_R4:
1095410968
case GGML_TYPE_IQ2_XS:
@@ -11419,6 +11433,7 @@ static void ggml_compute_forward_add1(
1141911433
case GGML_TYPE_Q6_K:
1142011434
case GGML_TYPE_Q6_K_R4:
1142111435
case GGML_TYPE_Q8_K_R8:
11436+
case GGML_TYPE_Q8_KV_R8:
1142211437
case GGML_TYPE_IQ2_XXS:
1142311438
case GGML_TYPE_IQ2_XXS_R4:
1142411439
case GGML_TYPE_IQ2_XS:
@@ -11586,6 +11601,7 @@ static void ggml_compute_forward_acc(
1158611601
case GGML_TYPE_Q6_K:
1158711602
case GGML_TYPE_Q6_K_R4:
1158811603
case GGML_TYPE_Q8_K_R8:
11604+
case GGML_TYPE_Q8_KV_R8:
1158911605
case GGML_TYPE_IQ2_XXS:
1159011606
case GGML_TYPE_IQ2_XXS_R4:
1159111607
case GGML_TYPE_IQ2_XS:
@@ -14793,6 +14809,7 @@ static void ggml_compute_forward_out_prod(
1479314809
case GGML_TYPE_Q6_K:
1479414810
case GGML_TYPE_Q6_K_R4:
1479514811
case GGML_TYPE_Q8_K_R8:
14812+
case GGML_TYPE_Q8_KV_R8:
1479614813
case GGML_TYPE_IQ2_XXS:
1479714814
case GGML_TYPE_IQ2_XXS_R4:
1479814815
case GGML_TYPE_IQ2_XS:
@@ -15200,6 +15217,7 @@ static void ggml_compute_forward_set(
1520015217
case GGML_TYPE_Q6_K:
1520115218
case GGML_TYPE_Q6_K_R4:
1520215219
case GGML_TYPE_Q8_K_R8:
15220+
case GGML_TYPE_Q8_KV_R8:
1520315221
case GGML_TYPE_IQ2_XXS:
1520415222
case GGML_TYPE_IQ2_XXS_R4:
1520515223
case GGML_TYPE_IQ2_XS:
@@ -15502,6 +15520,7 @@ static void ggml_compute_forward_get_rows(
1550215520
case GGML_TYPE_Q6_K:
1550315521
case GGML_TYPE_Q6_K_R4:
1550415522
case GGML_TYPE_Q8_K_R8:
15523+
case GGML_TYPE_Q8_KV_R8:
1550515524
case GGML_TYPE_IQ2_XXS:
1550615525
case GGML_TYPE_IQ2_XXS_R4:
1550715526
case GGML_TYPE_IQ2_XS:
@@ -16131,6 +16150,7 @@ static void ggml_compute_forward_clamp(
1613116150
case GGML_TYPE_Q6_K:
1613216151
case GGML_TYPE_Q6_K_R4:
1613316152
case GGML_TYPE_Q8_K_R8:
16153+
case GGML_TYPE_Q8_KV_R8:
1613416154
case GGML_TYPE_Q8_KR8:
1613516155
case GGML_TYPE_IQ2_XXS:
1613616156
case GGML_TYPE_IQ2_XXS_R4:
@@ -22998,6 +23018,7 @@ size_t ggml_quantize_chunk(
2299823018
case GGML_TYPE_Q6_K: result = quantize_q6_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
2299923019
case GGML_TYPE_Q6_K_R4: result = quantize_q6_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
2300023020
case GGML_TYPE_Q8_K_R8: result = quantize_q8_k_r8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
23021+
case GGML_TYPE_Q8_KV_R8:result = quantize_q8_KV_r8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
2300123022
case GGML_TYPE_IQ2_XXS: result = quantize_iq2_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
2300223023
case GGML_TYPE_IQ2_XXS_R4:result = quantize_iq2_xxs_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
2300323024
case GGML_TYPE_IQ2_XS: result = quantize_iq2_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;

ggml/src/iqk/iqk_mul_mat.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ struct MulMat {
270270
case GGML_TYPE_Q4_K_R4:
271271
case GGML_TYPE_Q5_K_R4:
272272
case GGML_TYPE_Q8_KV:
273+
case GGML_TYPE_Q8_KV_R8:
273274
case GGML_TYPE_Q8_K_R8: return 8;
274275
case GGML_TYPE_Q4_0_R8:
275276
case GGML_TYPE_Q8_0_R8:
@@ -303,6 +304,7 @@ struct MulMat {
303304
case GGML_TYPE_Q4_0_R8:
304305
case GGML_TYPE_Q8_0_R8:
305306
case GGML_TYPE_Q8_KV:
307+
case GGML_TYPE_Q8_KV_R8:
306308
case GGML_TYPE_Q8_K_R8: return 8;
307309
case GGML_TYPE_BF16_R16: return 16;
308310
default: return 1;
@@ -9352,6 +9354,21 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
93529354
mm.funcs[7] = mul_mat_q8_KV_q8_KV<8>;
93539355
#ifdef HAVE_FANCY_SIMD
93549356
mm.func16 = mul_mat_q8_KV_q8_KV<16>;
9357+
#endif
9358+
expected_typeB = GGML_TYPE_Q8_KV;
9359+
break;
9360+
case GGML_TYPE_Q8_KV_R8:
9361+
assert (ne00 % 32 == 0);
9362+
mm.funcs[0] = mul_mat_q8_KV_r8_q8_KV<1>;
9363+
mm.funcs[1] = mul_mat_q8_KV_r8_q8_KV<2>;
9364+
mm.funcs[2] = mul_mat_q8_KV_r8_q8_KV<3>;
9365+
mm.funcs[3] = mul_mat_q8_KV_r8_q8_KV<4>;
9366+
mm.funcs[4] = mul_mat_q8_KV_r8_q8_KV<5>;
9367+
mm.funcs[5] = mul_mat_q8_KV_r8_q8_KV<6>;
9368+
mm.funcs[6] = mul_mat_q8_KV_r8_q8_KV<7>;
9369+
mm.funcs[7] = mul_mat_q8_KV_r8_q8_KV<8>;
9370+
#ifdef HAVE_FANCY_SIMD
9371+
mm.func16 = mul_mat_q8_KV_r8_q8_KV<16>;
93559372
#endif
93569373
expected_typeB = GGML_TYPE_Q8_KV;
93579374
break;
@@ -14364,6 +14381,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
1436414381
m.func16 = mul_mat_q8_KV_q8_KV<16>;
1436514382
expected_Btype = GGML_TYPE_Q8_KV;
1436614383
break;
14384+
case GGML_TYPE_Q8_KV_R8:
14385+
SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_KV_r8_q8_KV);
14386+
expected_Btype = GGML_TYPE_Q8_KV;
14387+
break;
1436714388
case GGML_TYPE_IQ2_K_R4:
1436814389
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq2_k_r4_q8_k);
1436914390
expected_Btype = GGML_TYPE_Q8_K;

ggml/src/iqk/iqk_quantize.cpp

Lines changed: 148 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5509,6 +5509,150 @@ void vec_dot_q8_k_r8_q8_k(int n, float * s, size_t bs, const void * vx, size_t b
55095509
GGML_UNUSED(by);
55105510
}
55115511

5512+
//
5513+
// ========================================= q8_KV_r8
5514+
//
5515+
5516+
void quantize_row_q8_KV_r8_ref(const float * x, void * y, int64_t k) {
5517+
quantize_q8_KV_r8(x, y, 8, k/8, nullptr);
5518+
}
5519+
5520+
void quantize_row_q8_KV_r8(const float * x, void * y, int64_t k) {
5521+
quantize_q8_KV_r8(x, y, 8, k/8, nullptr);
5522+
}
5523+
5524+
static void repack_q8_KV(int nrows, int n_per_row, const char * cx, char * cy, [[maybe_unused]] bool online) {
5525+
GGML_ASSERT(nrows%8 == 0);
5526+
GGML_ASSERT(n_per_row%16 == 0);
5527+
auto row_size_x = ggml_row_size(GGML_TYPE_Q8_KV, n_per_row);
5528+
auto row_size_y = ggml_row_size(GGML_TYPE_Q8_KV_R8, n_per_row);
5529+
const int8_t * x8[8];
5530+
#ifdef __ARM_NEON
5531+
int8x16x2_t m0, m1, m2, m3;
5532+
#endif
5533+
for (int row = 0; row < nrows; row += 8) {
5534+
auto dy = (float *)cy;
5535+
auto qy = (int8_t *)(dy + 8);
5536+
for (int k = 0; k < 8; ++k) {
5537+
auto dx = (const float *)(cx + k*row_size_x);
5538+
dy[k] = dx[0];
5539+
x8[k] = (const int8_t *)(dx + 2);
5540+
}
5541+
for (int ib = 0; ib < n_per_row/16; ++ib) {
5542+
#ifdef __AVX2__
5543+
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
5544+
auto m0 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[4]+ib), _mm_loadu_si128((const __m128i *)x8[0]+ib));
5545+
auto m1 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[5]+ib), _mm_loadu_si128((const __m128i *)x8[1]+ib));
5546+
auto m2 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[6]+ib), _mm_loadu_si128((const __m128i *)x8[2]+ib));
5547+
auto m3 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[7]+ib), _mm_loadu_si128((const __m128i *)x8[3]+ib));
5548+
auto t0 = _mm256_unpacklo_epi32(m0, m1);
5549+
auto t1 = _mm256_unpacklo_epi32(m2, m3);
5550+
auto t2 = _mm256_unpackhi_epi32(m0, m1);
5551+
auto t3 = _mm256_unpackhi_epi32(m2, m3);
5552+
m0 = _mm256_unpacklo_epi64(t0, t1);
5553+
m1 = _mm256_unpackhi_epi64(t0, t1);
5554+
m2 = _mm256_unpacklo_epi64(t2, t3);
5555+
m3 = _mm256_unpackhi_epi64(t2, t3);
5556+
#ifdef HAVE_FANCY_SIMD
5557+
if (online) {
5558+
m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127));
5559+
m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127));
5560+
m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127));
5561+
m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127));
5562+
}
5563+
#endif
5564+
_mm256_storeu_si256((__m256i *)qy + 4*ib+0, m0);
5565+
_mm256_storeu_si256((__m256i *)qy + 4*ib+1, m1);
5566+
_mm256_storeu_si256((__m256i *)qy + 4*ib+2, m2);
5567+
_mm256_storeu_si256((__m256i *)qy + 4*ib+3, m3);
5568+
#elif defined __ARM_NEON
5569+
m0.val[0] = vld1q_s8(x8[0]+16*ib); m0.val[1] = vld1q_s8(x8[4]+16*ib);
5570+
m1.val[0] = vld1q_s8(x8[1]+16*ib); m1.val[1] = vld1q_s8(x8[5]+16*ib);
5571+
m2.val[0] = vld1q_s8(x8[2]+16*ib); m2.val[1] = vld1q_s8(x8[6]+16*ib);
5572+
m3.val[0] = vld1q_s8(x8[3]+16*ib); m3.val[1] = vld1q_s8(x8[7]+16*ib);
5573+
auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0]));
5574+
auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0]));
5575+
m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
5576+
m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
5577+
m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
5578+
m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
5579+
row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1]));
5580+
row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1]));
5581+
m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
5582+
m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
5583+
m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
5584+
m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
5585+
vst1q_s8_x2(qy + 0 + 128*ib, m0);
5586+
vst1q_s8_x2(qy + 32 + 128*ib, m1);
5587+
vst1q_s8_x2(qy + 64 + 128*ib, m2);
5588+
vst1q_s8_x2(qy + 96 + 128*ib, m3);
5589+
#else
5590+
// TODO
5591+
for (int l = 0; l < 4; ++l) {
5592+
for (int k = 0; k < 8; ++k) for (int i = 0; i < 4; ++i) {
5593+
y[ib].qs[32*l+4*k+i+ 0] = x8[k][ib].qs[i+4*l+ 0];
5594+
y[ib].qs[32*l+4*k+i+128] = x8[k][ib].qs[i+4*l+16];
5595+
}
5596+
}
5597+
#endif
5598+
5599+
}
5600+
cx += 8*row_size_x;
5601+
cy += online ? 8*row_size_x : 8*row_size_y;
5602+
//So, if we are run-time-repacking (online = true) we don't want to change the stride, so we just leave some unused space at the end of each row
5603+
}
5604+
}
5605+
#ifdef HAVE_FANCY_SIMD
5606+
static void modify_q8_KV_r8(int64_t k, char * cy) {
5607+
int8_t * q8 = (int8_t *)(cy + 8*sizeof(float));
5608+
for (int j = 0; j < k; ++j) q8[j] += 127;
5609+
}
5610+
#endif
5611+
5612+
size_t quantize_q8_KV_r8(const float * src, void * dst, int64_t nrows, int64_t n_per_row, [[maybe_unused]] const float * imatrix) {
5613+
GGML_ASSERT(nrows%8 == 0);
5614+
GGML_ASSERT(n_per_row%16 == 0);
5615+
char * qcur = (char *)dst;
5616+
auto row_size_0 = ggml_row_size(GGML_TYPE_Q8_KV, n_per_row);
5617+
auto row_size_1 = ggml_row_size(GGML_TYPE_Q8_KV_R8, n_per_row);
5618+
std::vector<char> qtmp(8*row_size_0);
5619+
for (int row = 0; row < nrows; row += 8) {
5620+
quantize_q8_KV(src, (void *)qtmp.data(), 8, n_per_row, imatrix);
5621+
repack_q8_KV(8, n_per_row, qtmp.data(), qcur, false);
5622+
qcur += 8*row_size_1;
5623+
src += 8*n_per_row;
5624+
}
5625+
return nrows*row_size_1;
5626+
}
5627+
5628+
void dequantize_row_q8_KV_r8(const void * vx, float * y, int64_t k) {
5629+
auto n_per_row = k/8;
5630+
float * y8[8];
5631+
for (int k = 0; k < 8; ++k) y8[k] = y + n_per_row*k;
5632+
auto dptr = (const float *)vx;
5633+
auto q8 = (const int8_t *)(dptr + 8);
5634+
for (int ib = 0; ib < n_per_row/16; ++ib) {
5635+
for (int k = 0; k < 8; ++k) {
5636+
for (int l = 0; l < 4; ++l) {
5637+
for (int i = 0; i < 4; ++i) y8[k][16*ib + 4*l + i] = dptr[k] * q8[128*ib + 32*l + 4*k + i];
5638+
}
5639+
}
5640+
}
5641+
}
5642+
5643+
void vec_dot_q8_KV_r8_q8_KV(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
5644+
#if GGML_USE_IQK_MULMAT
5645+
if (iqk_mul_mat(1, 1, n, GGML_TYPE_Q8_KV_R8, vx, 0, GGML_TYPE_Q8_KV, vy, 0, s, 0, 0, 1)) {
5646+
return;
5647+
}
5648+
#endif
5649+
GGML_ASSERT(n%QK4_NL == 0);
5650+
GGML_ASSERT(nrc == 1);
5651+
GGML_UNUSED(bs);
5652+
GGML_UNUSED(bx);
5653+
GGML_UNUSED(by);
5654+
}
5655+
55125656
//
55135657
// ========================================= bf16_r4
55145658
//
@@ -6610,8 +6754,9 @@ bool iqk_modify_tensor(struct ggml_tensor * tensor) {
66106754
{ GGML_TYPE_Q4_0_R8, {modify_q4_0_r8, 8} },
66116755
#endif
66126756
#ifdef HAVE_FANCY_SIMD
6613-
{ GGML_TYPE_Q8_0_R8, {modify_q8_0_r8, 8} },
6614-
{ GGML_TYPE_Q8_K_R8, {modify_q8_k_r8, 8} },
6757+
{ GGML_TYPE_Q8_0_R8, {modify_q8_0_r8, 8} },
6758+
{ GGML_TYPE_Q8_K_R8, {modify_q8_k_r8, 8} },
6759+
{ GGML_TYPE_Q8_KV_R8, {modify_q8_KV_r8, 8} },
66156760
#endif
66166761
};
66176762
auto it = k_mod_map.find(tensor->type);
@@ -6670,6 +6815,7 @@ void iqk_repack_tensor(struct ggml_tensor * tensor) {
66706815
{ GGML_TYPE_Q6_0, { GGML_TYPE_Q6_0_R4, 4, (Repack::repack_func)repack_q6_0} },
66716816
{ GGML_TYPE_Q8_0, { GGML_TYPE_Q8_0_R8, 8, (Repack::repack_func)repack_q8_0} },
66726817
{ GGML_TYPE_Q8_K, { GGML_TYPE_Q8_K_R8, 8, (Repack::repack_func)repack_q8_k} },
6818+
{ GGML_TYPE_Q8_KV, { GGML_TYPE_Q8_KV_R8, 8, (Repack::repack_func)repack_q8_KV} },
66736819
#ifdef __AVX512BF16__
66746820
{ GGML_TYPE_BF16, { GGML_TYPE_BF16_R16, 16, (Repack::repack_func)repack_bf16<ggml_bf16_t>}},
66756821
{ GGML_TYPE_F16, { GGML_TYPE_BF16_R16, 16, (Repack::repack_func)repack_bf16<ggml_half>} },

ggml/src/iqk/iqk_quantize.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,12 @@ size_t quantize_q8_KV(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst,
223223
void dequantize_row_q8_KV(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
224224
void vec_dot_q8_KV_q8_KV(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
225225

226+
void quantize_row_q8_KV_r8_ref(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
227+
void quantize_row_q8_KV_r8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
228+
size_t quantize_q8_KV_r8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
229+
void dequantize_row_q8_KV_r8(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
230+
void vec_dot_q8_KV_r8_q8_KV(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
231+
226232
void iqk_quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
227233
void quantize_row_q8_K64_ref(const float * GGML_RESTRICT x, block_q8_K64 * GGML_RESTRICT y, int64_t k);
228234
void quantize_row_q8_K64(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);

include/llama.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ extern "C" {
207207
LLAMA_FTYPE_MOSTLY_IQ4_K_R4 = 340, // except 1d tensors
208208
LLAMA_FTYPE_MOSTLY_IQ5_K_R4 = 341, // except 1d tensors
209209
LLAMA_FTYPE_MOSTLY_IQ4_KS_R4 = 345, // except 1d tensors
210+
LLAMA_FTYPE_MOSTLY_Q8_KV_R8 = 398, // except 1d tensors
210211
LLAMA_FTYPE_MOSTLY_Q8_K_R8 = 399, // except 1d tensors
211212

212213
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file

0 commit comments

Comments (0)