Commit 2f2fd15

Revived PR
1 parent 7d51644 commit 2f2fd15

File tree

    src/llama-graph.cpp
    src/llama-kv-cache-unified.cpp

2 files changed (+37, -10 lines)

src/llama-graph.cpp

Lines changed: 10 additions & 2 deletions
@@ -1257,7 +1257,11 @@ ggml_tensor * llm_graph_context::build_attn(
     // store to KV cache
     {
         ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
+
+        // note: MLA with flash attention now uses the last 512 elements of K-cache in place of a V-cache
+        if (!v_mla || !cparams.flash_attn) {
+            ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
+        }
     }
 
     const auto & kq_mask = inp->get_kq_mask();
@@ -1341,7 +1345,11 @@ ggml_tensor * llm_graph_context::build_attn(
     // store to KV cache
     {
         ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
+
+        // note: MLA with flash attention now uses the last 512 elements of K-cache in place of a V-cache
+        if (!v_mla || !cparams.flash_attn) {
+            ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
+        }
     }
 
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
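For reference, here is a minimal standalone sketch of the store condition introduced in both hunks above. The names v_mla and flash_attn mirror the diff; needs_v_store is a hypothetical helper used only for illustration, not a llama.cpp function.

    #include <cstdio>

    // Sketch: should build_attn still copy v_cur into a separate V-cache?
    // v_mla_present - the layer carries an MLA v_mla projection tensor
    // flash_attn    - flash attention is enabled (cparams.flash_attn)
    // The copy is skipped only when both hold; in that case the V rows are
    // later read back out of the K-cache instead.
    static bool needs_v_store(bool v_mla_present, bool flash_attn) {
        return !v_mla_present || !flash_attn;
    }

    int main() {
        for (int v = 0; v <= 1; ++v) {
            for (int f = 0; f <= 1; ++f) {
                std::printf("v_mla=%d flash_attn=%d -> store V: %s\n",
                            v, f, needs_v_store(v, f) ? "yes" : "no");
            }
        }
        return 0;
    }

The only combination that skips the store is an MLA layer with flash attention enabled, which is exactly the configuration served by the get_v() view added in the second file below.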

src/llama-kv-cache-unified.cpp

Lines changed: 27 additions & 8 deletions
@@ -62,6 +62,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
 
     cells.resize(kv_size);
 
+    const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0);
+
     for (uint32_t il = 0; il < hparams.n_layer; il++) {
         if (filter && !filter(il)) {
             LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
@@ -93,7 +95,9 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         ggml_tensor * v;
 
         k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size);
-        v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size);
+
+        // note: MLA with flash attention now uses the last 512 elements of K-cache in place of a V-cache
+        v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, !is_mla || v_trans ? kv_size : 0);
 
         ggml_format_name(k, "cache_k_l%d", il);
         ggml_format_name(v, "cache_v_l%d", il);
@@ -700,7 +704,9 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
 }
 
 bool llama_kv_cache_unified::get_can_shift() const {
-    return true;
+    const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0);
+
+    return !is_mla || v_trans; // TODO: allow context shifting for MLA with flash attention
 }
 
 uint32_t llama_kv_cache_unified::get_size() const {
@@ -733,12 +739,25 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
     auto * v = layers[ikv].v;
 
     if (!v_trans) {
-        // note: v->nb[1] <= v->nb[2]
-        return ggml_view_3d(ctx, v,
-                hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv,
-                ggml_row_size(v->type, hparams.n_embd_head_v),    // v->nb[1]
-                ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
-                0);
+        const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0);
+
+        if (!is_mla) {
+            // note: v->nb[1] <= v->nb[2]
+            return ggml_view_3d(ctx, v,
+                    hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv,
+                    ggml_row_size(v->type, hparams.n_embd_head_v),    // v->nb[1]
+                    ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
+                    0);
+        } else {
+            auto * k = layers[ikv].k;
+
+            // note: v->nb[1] == v->nb[2] for MLA as transforms into MQA
+            return ggml_view_3d(ctx, k,
+                    hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv,
+                    ggml_row_size(k->type, hparams.n_embd_head_k),    // v->nb[1]
+                    ggml_row_size(k->type, hparams.n_embd_k_gqa(il)), // v->nb[2]
+                    hparams.n_embd_head_k - hparams.n_embd_head_v);   // offset by n_rot elements
+        }
     }
 
     // note: v->nb[1] > v->nb[2]
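To make the reuse concrete: with flash attention, an MLA layer now allocates its V tensor with zero rows, and get_v() instead returns a view into the K tensor offset by n_embd_head_k - n_embd_head_v elements, i.e. past the n_rot rotary elements at the front of each K row, so the last n_embd_head_v elements of every K row double as the V row (per the diff's own comments). The sketch below works through the arithmetic; the concrete sizes (n_rot = 64, n_embd_head_v = 512, f16 cache, 61 layers, 32k context) are DeepSeek-style example values assumed purely for illustration.

    #include <cstdint>
    #include <cstdio>

    int main() {
        // Assumed example sizes (DeepSeek-style MLA-as-MQA head):
        const uint64_t n_rot         = 64;                    // rotary part at the front of each K row (assumption)
        const uint64_t n_embd_head_v = 512;                   // "last 512 elements of K-cache" reused as V
        const uint64_t n_embd_head_k = n_rot + n_embd_head_v; // 576 elements per KV head

        const uint64_t kv_size   = 32768;  // assumed cache length in tokens
        const uint64_t n_layer   = 61;     // assumed layer count
        const uint64_t elem_size = 2;      // assumed f16 cache type, 2 bytes per element

        // The V view into each K row starts after the rotary part ("offset by n_rot elements").
        const uint64_t v_view_offset = n_embd_head_k - n_embd_head_v; // 64 elements

        // Per-layer cache size with a separate V tensor vs. with V read out of K.
        const uint64_t with_v    = (n_embd_head_k + n_embd_head_v) * kv_size * elem_size;
        const uint64_t without_v = n_embd_head_k * kv_size * elem_size;

        std::printf("V view starts %llu elements into each K row\n",
                    (unsigned long long) v_view_offset);
        std::printf("per layer: %llu MiB with V-cache, %llu MiB without; ~%llu MiB saved over %llu layers\n",
                    (unsigned long long) (with_v >> 20),
                    (unsigned long long) (without_v >> 20),
                    (unsigned long long) (((with_v - without_v) * n_layer) >> 20),
                    (unsigned long long) n_layer);
        return 0;
    }

Note the trade-off recorded in get_can_shift(): while the K-cache is reused this way, context shifting is disabled for MLA with flash attention (the TODO in the third hunk above).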
