Commit a0ebfdd

ikawrakow and Iwan Kawrakow authored
Q8_KV: 8-bit quantization type targeting the KV cache (#208)
* Adding q8_KV - Basics + AVX2 gemm/gemv

* q8_KV: Better AVX2 gemm

* q8_KV: Better Zen4 gemm

  We get 225.7 t/s for L3-8B. In comparison, q8_0 without run-time repacking is at 169 t/s.

* q8_KV: AVX2 gemm/gemv

  We get 254 t/s for L3-8B vs 194 t/s for q8_0 without rtr.

* q8_KV: be able to use it for K cache

  This required quite a few fixes in ggml and llama.cpp:
  * ggml: do not calculate row size as n/block_size*type_size. I had removed most of this when implementing the quants with per-row scale, but it was still lurking in ggml_copy. Not sure if these were the last remnants of ggml-style row sizes, or if there are still places left.
  * llama.cpp: get rid of the 1D K-cache assumption. Create and manage the K cache as a 2D tensor so we can have per-row metadata as needed by q8_KV.

  Using q8_KV for the K cache results in non-negligible performance gains. More details to follow, but for DeepSeek-Lite with MLA, we get an 18% speedup for PP-8192 compared to a q8_0 K cache.

* q8_KV: be able to use it for K cache in FA

* q8_KV: repack it for K*Q in FA

* q8_KV: slightly faster gemv on Zen4

* q8_KV: slightly faster gemv on Zen4

* q8_KV: ARM_NEON

  We get PP-512 = 167 t/s for L3-8B without interleaving! We do the interleaving on the fly, so I wonder if this could be done for other quants as well.

* q8_KV: use it in FA on NEON

* q8_KV_r8 - repacked q8_KV

  On Zen4 it is slower than q8_k_r8 (292 vs 370 t/s). This makes no sense whatsoever, as the q8_KV_r8 GEMM is basically the q8_k_r8 GEMM with the unnecessary block handling removed (so one would think it would be faster).

* q8_KV_r8: don't use nrc_y = 16 on Zen4

  This is faster - 350 t/s. Why? Much better than the 290 t/s we had before, but still slower than the 370 t/s for q8_k_r8.

* q8_KV: nrc_y = 16 also doesn't pay off in FA

* Minor

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 047ba89 commit a0ebfdd
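The core idea behind q8_KV, per the commit message, is a single scale for an entire row instead of one per 32-weight block, leaving the quants as a flat int8 array that is cheap to interleave on the fly and to write into a K-cache row. A minimal sketch of per-row 8-bit quantization under that reading (the real quantize_row_q8_KV also writes 8 bytes of per-row metadata whose exact layout is not spelled out in this diff):

#include <math.h>
#include <stdint.h>

// Illustrative only: one scale d for all n weights of the row, then
// 1 byte per weight - no per-block scales as in q8_0/q8_K.
static void quantize_row_per_row_scale(const float * x, int8_t * q, float * d, int n) {
    float amax = 0.0f;                                   // row-wise abs max
    for (int i = 0; i < n; ++i) amax = fmaxf(amax, fabsf(x[i]));
    *d = amax / 127.0f;                                  // single per-row scale
    const float id = amax > 0.0f ? 127.0f / amax : 0.0f;
    for (int i = 0; i < n; ++i) q[i] = (int8_t)roundf(x[i] * id);
}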

File tree

11 files changed: 983 additions & 34 deletions


common/common.cpp

Lines changed: 3 additions & 0 deletions
@@ -2259,6 +2259,9 @@ static ggml_type kv_cache_type_from_str(const std::string & s) {
     if (s == "q6_0") {
         return GGML_TYPE_Q6_0;
     }
+    if (s == "q8_KV") {
+        return GGML_TYPE_Q8_KV;
+    }
 
     throw std::runtime_error("Invalid cache type: " + s);
 }
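This parser entry is what lets the new type be requested for the K cache by name. The commit message notes that llama.cpp then has to create the K cache as a 2D tensor rather than a 1D blob so each cache row can carry its q8_KV metadata; roughly as below (a sketch only - the function and parameter names are illustrative, and the actual llama.cpp changes are not part of this hunk):

#include "ggml.h"

// Sketch: one cache row per position, so row i starts at
// (char *)t->data + i * ggml_row_size(t->type, n_embd_k_gqa),
// which includes the per-row header. A 1D cache indexed by element
// count alone cannot account for that header.
static struct ggml_tensor * make_k_cache(struct ggml_context * ctx,
                                         int64_t n_embd_k_gqa, int64_t kv_size) {
    return ggml_new_tensor_2d(ctx, GGML_TYPE_Q8_KV, n_embd_k_gqa, kv_size);
}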

examples/llama-bench/llama-bench.cpp

Lines changed: 3 additions & 0 deletions
@@ -339,6 +339,9 @@ static ggml_type ggml_type_from_name(const std::string & s) {
     if (s == "q6_0") {
         return GGML_TYPE_Q6_0;
     }
+    if (s == "q8_KV") {
+        return GGML_TYPE_Q8_KV;
+    }
 
     return GGML_TYPE_COUNT;
 }

examples/quantize/quantize.cpp

Lines changed: 2 additions & 0 deletions
@@ -56,6 +56,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
     { "Q5_0_R4",  LLAMA_FTYPE_MOSTLY_Q5_0_R4,  " 5.50 bpw quantization", },
     { "Q6_0_R4",  LLAMA_FTYPE_MOSTLY_Q6_0_R4,  " 6.50 bpw quantization", },
     { "Q8_0_R8",  LLAMA_FTYPE_MOSTLY_Q8_0_R8,  " 8.50 bpw quantization", },
+    { "Q8_KV",    LLAMA_FTYPE_MOSTLY_Q8_KV,    " 8.00 bpw quantization", },
     { "IQ4_XS",   LLAMA_FTYPE_MOSTLY_IQ4_XS,   " 4.25 bpw non-linear quantization", },
     { "IQ4_KS",   LLAMA_FTYPE_MOSTLY_IQ4_KS,   " 4.25 bpw non-linear quantization", },
     { "IQ4_KS_R4",LLAMA_FTYPE_MOSTLY_IQ4_KS_R4,"IQ4_KS repacked", },
@@ -82,6 +83,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
     { "Q6_K",     LLAMA_FTYPE_MOSTLY_Q6_K,     " 5.15G, +0.0008 ppl @ LLaMA-v1-7B", },
     { "Q6_K_R4",  LLAMA_FTYPE_MOSTLY_Q6_K_R4,  "Q6_K repacked", },
     { "Q8_K_R8",  LLAMA_FTYPE_MOSTLY_Q8_K_R8,  "Q8_K repacked", },
+    { "Q8_KV_R8", LLAMA_FTYPE_MOSTLY_Q8_KV_R8, "Q8_KV repacked", },
     { "Q8_0",     LLAMA_FTYPE_MOSTLY_Q8_0,     " 6.70G, +0.0004 ppl @ LLaMA-v1-7B", },
     { "Q4_0_4_4", LLAMA_FTYPE_MOSTLY_Q4_0_4_4, " 4.34G, +0.4685 ppl @ Llama-3-8B", },
     { "Q4_0_4_8", LLAMA_FTYPE_MOSTLY_Q4_0_4_8, " 4.34G, +0.4685 ppl @ Llama-3-8B", },

ggml/include/ggml.h

Lines changed: 4 additions & 0 deletions
@@ -416,6 +416,7 @@ extern "C" {
         GGML_TYPE_Q8_K32  = 148,
         GGML_TYPE_Q8_KR8  = 149,
         GGML_TYPE_Q8_K128 = 150,
+        GGML_TYPE_Q8_KV   = 151,
 
         GGML_TYPE_Q4_0_R8 = 202,
         GGML_TYPE_Q5_0_R4 = 206,
@@ -442,6 +443,7 @@ extern "C" {
         GGML_TYPE_IQ4_K_R4  = 339,
         GGML_TYPE_IQ5_K_R4  = 340,
         GGML_TYPE_IQ4_KS_R4 = 344,
+        GGML_TYPE_Q8_KV_R8  = 398,
         GGML_TYPE_Q8_K_R8   = 399,
         GGML_TYPE_COUNT,
     };
@@ -501,6 +503,7 @@ extern "C" {
         GGML_FTYPE_MOSTLY_IQ4_KS  = 137, // except 1d tensors
         GGML_FTYPE_MOSTLY_IQ2_KS  = 138, // except 1d tensors
         GGML_FTYPE_MOSTLY_IQ4_KSS = 139, // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q8_KV   = 140, // except 1d tensors
         //
         GGML_FTYPE_MOSTLY_Q4_0_R8 = 202, // except 1d tensors
         GGML_FTYPE_MOSTLY_Q8_0_R8 = 207, // except 1d tensors
@@ -527,6 +530,7 @@ extern "C" {
         GGML_FTYPE_MOSTLY_IQ4_K_R4  = 332, // except 1d tensors
         GGML_FTYPE_MOSTLY_IQ5_K_R4  = 333, // except 1d tensors
         GGML_FTYPE_MOSTLY_IQ4_KS_R4 = 337, // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q8_KV_R8  = 398, // except 1d tensors
         GGML_FTYPE_MOSTLY_Q8_K_R8   = 399, // except 1d tensors
     };

ggml/src/ggml-quants.c

Lines changed: 4 additions & 2 deletions
@@ -15214,8 +15214,10 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
         case GGML_TYPE_IQ3_K_R4: break;
         case GGML_TYPE_IQ4_K_R4: break;
         case GGML_TYPE_IQ5_K_R4: break;
-        case GGML_TYPE_IQ4_KS_R4: break;
-        case GGML_TYPE_Q8_K_R8: break;
+        case GGML_TYPE_IQ4_KS_R4:break;
+        case GGML_TYPE_Q8_KV_R8: break;
+        case GGML_TYPE_Q8_K_R8: break;
+        case GGML_TYPE_Q8_KV: break;
         case GGML_TYPE_BF16_R16: break;
         case GGML_TYPE_Q4_0_4_4:
         case GGML_TYPE_Q4_0_4_8:

ggml/src/ggml.c

Lines changed: 43 additions & 5 deletions
@@ -1362,6 +1362,30 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .from_float               = quantize_row_q8_K128,
         .row_meta_size            = 0,
     },
+    [GGML_TYPE_Q8_KV] = {
+        .type_name                = "q8_KV",
+        .blck_size                = 32,
+        .type_size                = 32,
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_q8_KV,
+        .from_float               = quantize_row_q8_KV,
+        .from_float_ref           = (ggml_from_float_t)quantize_row_q8_KV_ref,
+        .vec_dot                  = vec_dot_q8_KV_q8_KV,
+        .vec_dot_type             = GGML_TYPE_Q8_KV,
+        .row_meta_size            = 8,
+    },
+    [GGML_TYPE_Q8_KV_R8] = {
+        .type_name                = "q8_KV_r8",
+        .blck_size                = 32,
+        .type_size                = 32,
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_q8_KV_r8,
+        .from_float               = quantize_row_q8_KV_r8,
+        .from_float_ref           = (ggml_from_float_t)quantize_row_q8_KV_r8_ref,
+        .vec_dot                  = vec_dot_q8_KV_r8_q8_KV,
+        .vec_dot_type             = GGML_TYPE_Q8_KV,
+        .row_meta_size            = 4,
+    },
     [GGML_TYPE_Q8_K16] = {
         .type_name                = "q8_K16",
         .blck_size                = 64,
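The new row_meta_size field is what sets these entries apart from classic block quants: a q8_KV row is a fixed 8-byte header followed by the quants, and with blck_size == type_size == 32 the payload is exactly one byte per weight. The " 8.00 bpw" label added to quantize.cpp above then works out as follows (a worked example, not code from the patch):

#include <stddef.h>

// Row size for a type with per-row metadata, using the q8_KV traits above.
static size_t q8_KV_row_bytes(size_t n) {
    const size_t row_meta_size = 8, blck_size = 32, type_size = 32;
    return row_meta_size + (n / blck_size) * type_size;  // == 8 + n
}
// q8_KV_row_bytes(4096) == 4104: 8 bpw plus ~0.2% per-row header overhead.
// q8_KV_r8 plays the same game for the 8-row interleaved layout, with a
// 4-byte per-row header.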
@@ -4373,6 +4397,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
         case GGML_FTYPE_MOSTLY_Q5_1:      wtype = GGML_TYPE_Q5_1;      break;
         case GGML_FTYPE_MOSTLY_Q6_0:      wtype = GGML_TYPE_Q6_0;      break;
         case GGML_FTYPE_MOSTLY_Q8_0:      wtype = GGML_TYPE_Q8_0;      break;
+        case GGML_FTYPE_MOSTLY_Q8_KV:     wtype = GGML_TYPE_Q8_KV;     break;
         case GGML_FTYPE_MOSTLY_Q2_K:      wtype = GGML_TYPE_Q2_K;      break;
         case GGML_FTYPE_MOSTLY_Q2_K_R4:   wtype = GGML_TYPE_Q2_K_R4;   break;
         case GGML_FTYPE_MOSTLY_Q3_K:      wtype = GGML_TYPE_Q3_K;      break;
@@ -4384,6 +4409,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
         case GGML_FTYPE_MOSTLY_Q6_K:      wtype = GGML_TYPE_Q6_K;      break;
         case GGML_FTYPE_MOSTLY_Q6_K_R4:   wtype = GGML_TYPE_Q6_K_R4;   break;
         case GGML_FTYPE_MOSTLY_Q8_K_R8:   wtype = GGML_TYPE_Q8_K_R8;   break;
+        case GGML_FTYPE_MOSTLY_Q8_KV_R8:  wtype = GGML_TYPE_Q8_KV_R8;  break;
         case GGML_FTYPE_MOSTLY_IQ2_XXS:   wtype = GGML_TYPE_IQ2_XXS;   break;
         case GGML_FTYPE_MOSTLY_IQ2_XXS_R4: wtype = GGML_TYPE_IQ2_XXS_R4;break;
         case GGML_FTYPE_MOSTLY_IQ2_XS:    wtype = GGML_TYPE_IQ2_XS;    break;
@@ -9436,7 +9462,7 @@ static void ggml_compute_forward_dup_f16(
                 float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
 
                 size_t id = 0;
-                size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
+                size_t rs = ggml_row_size(dst->type, ne00); //nb0 * (ne00 / ggml_blck_size(dst->type));
                 char * dst_ptr = (char *) dst->data;
 
                 for (int i03 = 0; i03 < ne03; i03++) {
@@ -9722,7 +9748,7 @@ static void ggml_compute_forward_dup_bf16(
                 float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
 
                 size_t id = 0;
-                size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
+                size_t rs = ggml_row_size(dst->type, ne00); //nb0 * (ne00 / ggml_blck_size(dst->type));
                 char * dst_ptr = (char *) dst->data;
 
                 for (int i03 = 0; i03 < ne03; i03++) {
@@ -10042,7 +10068,7 @@ static void ggml_compute_forward_dup_f32(
                 ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float;
 
                 size_t id = 0;
-                size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
+                size_t rs = ggml_row_size(dst->type, ne00); //nb0 * (ne00 / ggml_blck_size(dst->type));
                 char * dst_ptr = (char *) dst->data;
 
                 for (int i03 = 0; i03 < ne03; i03++) {
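These three identical fixes in the dup/copy path are the "remnants of ggml-style row sizes" called out in the commit message: nb0 * (ne00 / ggml_blck_size(type)) counts only block payload, which undercounts every row of a type with per-row metadata. Concretely, with q8_KV as the destination type (nb0 == type_size == 32 for a contiguous quantized row):

// ne00 = 4096, dst->type = GGML_TYPE_Q8_KV:
//   old: rs = nb0 * (ne00 / ggml_blck_size(type)) = 32 * 128 = 4096 bytes
//   new: rs = ggml_row_size(type, ne00)           = 8 + 4096 = 4104 bytes
// With the old stride every destination row is written 8 bytes short, so
// rows drift out of place after the first one.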
@@ -10936,6 +10962,7 @@ static void ggml_compute_forward_add(
         case GGML_TYPE_Q6_K:
         case GGML_TYPE_Q6_K_R4:
         case GGML_TYPE_Q8_K_R8:
+        case GGML_TYPE_Q8_KV_R8:
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XXS_R4:
         case GGML_TYPE_IQ2_XS:
@@ -11406,6 +11433,7 @@ static void ggml_compute_forward_add1(
         case GGML_TYPE_Q6_K:
         case GGML_TYPE_Q6_K_R4:
         case GGML_TYPE_Q8_K_R8:
+        case GGML_TYPE_Q8_KV_R8:
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XXS_R4:
         case GGML_TYPE_IQ2_XS:
@@ -11573,6 +11601,7 @@ static void ggml_compute_forward_acc(
         case GGML_TYPE_Q6_K:
         case GGML_TYPE_Q6_K_R4:
         case GGML_TYPE_Q8_K_R8:
+        case GGML_TYPE_Q8_KV_R8:
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XXS_R4:
         case GGML_TYPE_IQ2_XS:
@@ -14061,7 +14090,7 @@ static void ggml_compute_forward_mul_mat(
     // nb01 >= nb00 - src0 is not transposed
     // compute by src0 rows
 
-#if GGML_USE_IQK_MULMAT || GGML_USE_LLAMAFILE
+#if GGML_USE_LLAMAFILE
     // broadcast factors
     const int64_t r2 = ne12 / ne02;
     const int64_t r3 = ne13 / ne03;
@@ -14344,7 +14373,7 @@ static void ggml_compute_forward_mul_mat_id(
 
     char * wdata_src1_end = (src1->type == vec_dot_type) ?
             (char *) params->wdata :
-            (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
+            (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, src1->ne[0])*ggml_nrows(src1), sizeof(int64_t));
 
     struct mmid_row_mapping {
         int32_t i1;
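The mul_mat_id change is the same bug at tensor scale: ggml_row_size(vec_dot_type, ggml_nelements(src1)) treats all of src1 as one giant row, so per-row metadata is reserved once instead of once per row. A quick check with q8_KV as the vec_dot_type (sizes are illustrative):

// src1: 512 rows of 4096 elements, vec_dot_type = GGML_TYPE_Q8_KV:
//   old: ggml_row_size(q8_KV, 512 * 4096)  = 8 + 2097152    = 2097160 bytes
//   new: ggml_row_size(q8_KV, 4096) * 512  = (8 + 4096)*512 = 2101248 bytes
// The old expression under-allocates wdata by 511 row headers (4088 bytes),
// so the quantized activations after the first row would overrun it.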
@@ -14768,6 +14797,7 @@ static void ggml_compute_forward_out_prod(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q6_0:
         case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q8_KV:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q2_K_R4:
         case GGML_TYPE_Q3_K:
@@ -14779,6 +14809,7 @@ static void ggml_compute_forward_out_prod(
         case GGML_TYPE_Q6_K:
         case GGML_TYPE_Q6_K_R4:
         case GGML_TYPE_Q8_K_R8:
+        case GGML_TYPE_Q8_KV_R8:
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XXS_R4:
         case GGML_TYPE_IQ2_XS:
@@ -15186,6 +15217,7 @@ static void ggml_compute_forward_set(
         case GGML_TYPE_Q6_K:
         case GGML_TYPE_Q6_K_R4:
         case GGML_TYPE_Q8_K_R8:
+        case GGML_TYPE_Q8_KV_R8:
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XXS_R4:
         case GGML_TYPE_IQ2_XS:
@@ -15473,6 +15505,7 @@ static void ggml_compute_forward_get_rows(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q6_0:
         case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q8_KV:
         case GGML_TYPE_Q8_1:
         case GGML_TYPE_Q8_0_X4:
         case GGML_TYPE_Q8_1_X4:
@@ -15487,6 +15520,7 @@ static void ggml_compute_forward_get_rows(
         case GGML_TYPE_Q6_K:
         case GGML_TYPE_Q6_K_R4:
         case GGML_TYPE_Q8_K_R8:
+        case GGML_TYPE_Q8_KV_R8:
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XXS_R4:
         case GGML_TYPE_IQ2_XS:
@@ -16116,6 +16150,7 @@ static void ggml_compute_forward_clamp(
         case GGML_TYPE_Q6_K:
         case GGML_TYPE_Q6_K_R4:
         case GGML_TYPE_Q8_K_R8:
+        case GGML_TYPE_Q8_KV_R8:
         case GGML_TYPE_Q8_KR8:
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XXS_R4:
@@ -16159,6 +16194,7 @@ static void ggml_compute_forward_clamp(
         case GGML_TYPE_Q8_K:
         case GGML_TYPE_Q8_K64:
         case GGML_TYPE_Q8_K128:
+        case GGML_TYPE_Q8_KV:
         case GGML_TYPE_Q8_K16:
         case GGML_TYPE_Q8_K32:
         case GGML_TYPE_Q4_0_4_4:
@@ -22970,6 +23006,7 @@ size_t ggml_quantize_chunk(
         case GGML_TYPE_Q5_1:    result = quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_Q6_0:    result = quantize_q6_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_Q8_0:    result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+        case GGML_TYPE_Q8_KV:   result = quantize_q8_KV(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_Q2_K:    result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_Q2_K_R4: result = quantize_q2_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_Q3_K:    result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
@@ -22981,6 +23018,7 @@ size_t ggml_quantize_chunk(
         case GGML_TYPE_Q6_K:    result = quantize_q6_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_Q6_K_R4: result = quantize_q6_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_Q8_K_R8: result = quantize_q8_k_r8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+        case GGML_TYPE_Q8_KV_R8:result = quantize_q8_KV_r8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_IQ2_XXS: result = quantize_iq2_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_IQ2_XXS_R4:result = quantize_iq2_xxs_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_IQ2_XS:  result = quantize_iq2_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
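With these dispatch entries in place, the generic quantization API can produce q8_KV and q8_KV_r8 data. A minimal usage sketch against the public ggml API (the helper name is ours; the buffer size follows from ggml_row_size as worked out above):

#include <stdlib.h>
#include "ggml.h"

// Quantize nrows rows of n_per_row floats to q8_KV; caller frees the result.
static void * quantize_to_q8_KV(const float * src, int64_t nrows, int64_t n_per_row) {
    const size_t row_size = ggml_row_size(GGML_TYPE_Q8_KV, n_per_row); // 8 + n_per_row
    void * dst = malloc(nrows * row_size);
    size_t written = ggml_quantize_chunk(GGML_TYPE_Q8_KV, src, dst,
                                         /*start=*/0, nrows, n_per_row,
                                         /*imatrix=*/NULL);
    (void) written; // equals nrows * row_size on success
    return dst;
}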
