Skip to content

Commit 0280b8d

Browse files
author
Iwan Kawrakow
committed
q8_KV: be able to use it for K cache
This required quite a few fixes in ggml and llama.cpp: * ggml: do not calculate row size as n/block_size*type_size. I had removed most of it when implementing the quants with per row scale, but it was still lurking in ggml_copy. Not sure if these were the last remnants of ggml-style row sizes, or if there are still places left * llama.cpp: get rid of the 1d K cache assumption. Create and manage the K-cache as a 2D tensor so we can have per row meta data as needed by q8_KV. Using q8_KV for K-cache results in non-negligible performance gains. More details to follow, but for DeepSeek-Lite with MLA, we get 18% speedup for PP-8192 compared to q8_0 K-cache.
1 parent a4ffe2e commit 0280b8d

File tree

4 files changed

+38
-14
lines changed

4 files changed

+38
-14
lines changed

common/common.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2259,6 +2259,9 @@ static ggml_type kv_cache_type_from_str(const std::string & s) {
22592259
if (s == "q6_0") {
22602260
return GGML_TYPE_Q6_0;
22612261
}
2262+
if (s == "q8_KV") {
2263+
return GGML_TYPE_Q8_KV;
2264+
}
22622265

22632266
throw std::runtime_error("Invalid cache type: " + s);
22642267
}

examples/llama-bench/llama-bench.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,9 @@ static ggml_type ggml_type_from_name(const std::string & s) {
339339
if (s == "q6_0") {
340340
return GGML_TYPE_Q6_0;
341341
}
342+
if (s == "q8_KV") {
343+
return GGML_TYPE_Q8_KV;
344+
}
342345

343346
return GGML_TYPE_COUNT;
344347
}

ggml/src/ggml.c

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1364,8 +1364,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
13641364
},
13651365
[GGML_TYPE_Q8_KV] = {
13661366
.type_name = "q8_KV",
1367-
.blck_size = 1,
1368-
.type_size = 1,
1367+
.blck_size = 32,
1368+
.type_size = 32,
13691369
.is_quantized = true,
13701370
.to_float = (ggml_to_float_t) dequantize_row_q8_KV,
13711371
.from_float = quantize_row_q8_KV,
@@ -9449,7 +9449,7 @@ static void ggml_compute_forward_dup_f16(
94499449
float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
94509450

94519451
size_t id = 0;
9452-
size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
9452+
size_t rs = ggml_row_size(dst->type, ne00); //nb0 * (ne00 / ggml_blck_size(dst->type));
94539453
char * dst_ptr = (char *) dst->data;
94549454

94559455
for (int i03 = 0; i03 < ne03; i03++) {
@@ -9735,7 +9735,7 @@ static void ggml_compute_forward_dup_bf16(
97359735
float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
97369736

97379737
size_t id = 0;
9738-
size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
9738+
size_t rs = ggml_row_size(dst->type, ne00); //nb0 * (ne00 / ggml_blck_size(dst->type));
97399739
char * dst_ptr = (char *) dst->data;
97409740

97419741
for (int i03 = 0; i03 < ne03; i03++) {
@@ -10055,7 +10055,7 @@ static void ggml_compute_forward_dup_f32(
1005510055
ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float;
1005610056

1005710057
size_t id = 0;
10058-
size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
10058+
size_t rs = ggml_row_size(dst->type, ne00); //nb0 * (ne00 / ggml_blck_size(dst->type));
1005910059
char * dst_ptr = (char *) dst->data;
1006010060

1006110061
for (int i03 = 0; i03 < ne03; i03++) {
@@ -14357,7 +14357,7 @@ static void ggml_compute_forward_mul_mat_id(
1435714357

1435814358
char * wdata_src1_end = (src1->type == vec_dot_type) ?
1435914359
(char *) params->wdata :
14360-
(char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
14360+
(char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, src1->ne[0])*ggml_nrows(src1), sizeof(int64_t));
1436114361

1436214362
struct mmid_row_mapping {
1436314363
int32_t i1;

src/llama.cpp

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3180,6 +3180,10 @@ static bool llama_kv_cache_init(
31803180
for (int i = 0; i < (int) n_layer; i++) {
31813181
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
31823182
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
3183+
const uint32_t n_head = hparams.n_head(i);
3184+
const uint32_t n_head_kv = hparams.n_head_kv(i);
3185+
const uint32_t n_embd_head_k= hparams.n_embd_head_k;
3186+
31833187

31843188
struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
31853189
ggml_tensor * k;
@@ -3201,7 +3205,8 @@ static bool llama_kv_cache_init(
32013205
const uint32_t kv_lora_rank = hparams.n_lora_kv;
32023206
LLAMA_LOG_INFO("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank);
32033207
#if MLA_USE_TRANSPOSED_CACHE
3204-
ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_k, (kv_lora_rank + n_embd_head_qk_rope)*kv_size);
3208+
ggml_tensor * kv = ggml_new_tensor_2d(ctx, cache.type_k, kv_lora_rank + n_embd_head_qk_rope, kv_size);
3209+
//ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_k, (kv_lora_rank + n_embd_head_qk_rope)*kv_size);
32053210
#else
32063211
ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_v, (kv_lora_rank + n_embd_head_qk_rope)*kv_size);
32073212
#endif
@@ -3215,7 +3220,10 @@ static bool llama_kv_cache_init(
32153220
n_mla++;
32163221
}
32173222
else {
3218-
k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
3223+
//printf("Creating cache tensors:\n");
3224+
//printf("n_embd_k_gqa = %d, kv_size = %d, n_head = %d, n_head_kv = %d, n_embd_head_k = %d\n", (int)n_embd_k_gqa, (int)kv_size, (int)n_head, (int)n_head_kv, (int)n_embd_head_k);
3225+
//k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
3226+
k = ggml_new_tensor_2d(ctx, type_k, n_embd_head_k, n_head_kv*kv_size);
32193227
v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
32203228
ggml_format_name(k, "cache_k_l%d", i);
32213229
ggml_format_name(v, "cache_v_l%d", i);
@@ -8285,11 +8293,20 @@ static void llm_build_kv_store(
82858293
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
82868294
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
82878295

8296+
const int64_t n_head = hparams.n_head(il);
8297+
const int64_t n_head_kv = hparams.n_head_kv(il);
8298+
const int64_t n_embd_head_k = hparams.n_embd_head_k;
8299+
const int64_t n_embd_head_v = hparams.n_embd_head_v;
8300+
82888301
GGML_ASSERT(kv.size == n_ctx);
82898302

8290-
struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa,
8291-
(ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head);
8292-
cb(k_cache_view, "k_cache_view", il);
8303+
//struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa,
8304+
// (ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head);
8305+
//cb(k_cache_view, "k_cache_view", il);
8306+
8307+
auto k_row_size = ggml_row_size(kv.k_l[il]->type, n_embd_head_k);
8308+
ggml_tensor * k_cache_view = ggml_view_2d(ctx, kv.k_l[il], n_embd_head_k, n_tokens*n_head_kv,
8309+
k_row_size, k_row_size*n_head_kv*kv_head);
82938310

82948311
// note: storing RoPE-ed version of K in the KV cache
82958312
ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
@@ -8708,7 +8725,7 @@ static struct ggml_tensor * llm_build_kqv(
87088725
struct ggml_tensor * k =
87098726
ggml_view_3d(ctx, kv.k_l[il],
87108727
n_embd_head_k, n_kv, n_head_kv,
8711-
ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa),
8728+
ggml_row_size(kv.k_l[il]->type, n_embd_head_k)*n_head_kv, //n_embd_k_gqa),
87128729
ggml_row_size(kv.k_l[il]->type, n_embd_head_k),
87138730
0);
87148731
cb(k, "k", il);
@@ -13509,8 +13526,9 @@ struct llm_build_context {
1350913526
ggml_tensor * kvr = ggml_concat(ctx0, kv_compressed, ggml_permute(ctx0, k_rope, 0, 2, 1, 3), 0);
1351013527
cb(kvr, "kvr", il);
1351113528

13512-
ggml_tensor * kv_cache_view = ggml_view_1d(ctx0, kv_self.kv_l[il], n_tokens*(kv_lora_rank + n_embd_head_qk_rope),
13513-
ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope)*kv_head);
13529+
auto row_size = ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope);
13530+
ggml_tensor * kv_cache_view = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_self.kv_l[il]->ne[0], n_tokens,
13531+
row_size, row_size*kv_head);
1351413532
ggml_build_forward_expand(gf, ggml_cpy(ctx0, kvr, kv_cache_view));
1351513533
ggml_tensor * kv_cache = ggml_view_2d(ctx0, kv_self.kv_l[il],
1351613534
kv_lora_rank + n_embd_head_qk_rope, n_kv,

0 commit comments

Comments
 (0)