Commit d72f9d5

committed

kludge-y kv cache management of mtp layer

1 parent 382135a

File tree

- src/llama-context.cpp
- src/llama-kv-cache-unified.cpp
- src/llama-kv-cache-unified.h
- tools/server/server.cpp

4 files changed: +32 -5 lines changed

src/llama-context.cpp

Lines changed: 19 additions & 4 deletions

@@ -7,6 +7,7 @@
 #include "llama-mmap.h"
 #include "llama-model.h"
 #include "llama-graph.h"
+#include "llama-kv-cache-unified.h"

 #include <cinttypes>
 #include <cstring>
@@ -3000,7 +3001,20 @@ void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
     const auto * model = llama_get_model(ctx);

     auto res_mtp = std::make_unique<llm_graph_result>(ctx->graph_max_nodes());
-    llama_memory_context_ptr mctx = ctx->mtp_memory_batch(batch_inp);
+    std::unique_ptr<llama_memory_context_i> mctx = ctx->mtp_memory_batch(batch_inp);
+
+    std::vector<uint32_t> idxs;
+    idxs.push_back(n_past);
+    llama_kv_cache_unified::slot_info sinfo = {
+        /*.s0   =*/ 0,
+        /*.s1   =*/ 0,
+        /*.strm =*/ { 0 },
+        /*.idxs =*/ { idxs },
+    };
+    llama_kv_cache_unified::slot_info_vec_t sinfos;
+    sinfos.push_back(sinfo);
+
+    static_cast<llama_kv_cache_unified_context *>(mctx.get())->set_sinfos(sinfos);
     const auto & ubatch_mtp = mctx->get_ubatch();

     //llama_ubatch ubatch_mtp;
@@ -3012,9 +3026,10 @@

     auto * last_embd = ctx->get_embeddings_ith(last_tok_idx);

-    if (mctx && !mctx->apply()) {
-        LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
-    }
+    //if (mctx && !mctx->set_n_kv()) {
+    //    LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
+    //}
+    static_cast<llama_kv_cache_unified_context *>(mctx.get())->set_n_kv();

     auto * gf = model->build_mtp_graph(*params_mtp, last_token_id, n_past);
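
Aside on the hunk above: instead of letting the unified KV cache pick cells itself when the memory context is applied, the caller hand-builds a slot_info that pins the single MTP draft token's K/V entry to cell n_past of stream 0 and injects it via set_sinfos(). A standalone sketch of the shape being built (mock types for illustration only; the real definitions live in src/llama-kv-cache-unified.h):

// Standalone illustration of the slot_info hand-built above; these mock
// types are NOT the real llama.cpp definitions, they only mirror the
// shape implied by the initializer in the diff.
#include <cstdint>
#include <cstdio>
#include <vector>

struct slot_info_mock {
    uint32_t s0;                              // first stream covered
    uint32_t s1;                              // last stream covered
    std::vector<uint32_t> strm;               // stream id per sequence
    std::vector<std::vector<uint32_t>> idxs;  // KV cell indices per stream
};

// Pin a single draft token's K/V entry to cell `n_past` of stream 0,
// mirroring the hunk above.
static slot_info_mock make_mtp_sinfo(uint32_t n_past) {
    return {
        /*.s0   =*/ 0,
        /*.s1   =*/ 0,
        /*.strm =*/ { 0 },
        /*.idxs =*/ { { n_past } },
    };
}

int main() {
    const slot_info_mock sinfo = make_mtp_sinfo(/*n_past=*/42);
    std::printf("MTP draft token -> stream %u, KV cell %u\n",
                sinfo.strm[0], sinfo.idxs[0][0]);
    return 0;
}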

src/llama-kv-cache-unified.cpp

Lines changed: 9 additions & 0 deletions

@@ -2322,6 +2322,11 @@ bool llama_kv_cache_unified_context::apply() {
     return true;
 }

+void llama_kv_cache_unified_context::set_n_kv() {
+    n_kv = kv->get_n_kv();
+}
+
+
 llama_memory_status llama_kv_cache_unified_context::get_status() const {
     return status;
 }
@@ -2384,6 +2389,10 @@ void llama_kv_cache_unified_context::set_input_pos_bucket(ggml_tensor * dst, con
     kv->set_input_pos_bucket(dst, ubatch);
 }

+void llama_kv_cache_unified_context::set_sinfos(llama_kv_cache_unified::slot_info_vec_t new_sinfos) {
+    sinfos = new_sinfos;
+}
+
 uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) {
     // the FA kernels require padding to avoid extra runtime boundary checks
     return cparams.flash_attn ? 256u : 32u;
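
Both additions are thin by design: set_sinfos() lets the caller overwrite the slot assignments the context would otherwise compute, and set_n_kv() refreshes the cached n_kv that graph construction reads, which is the one piece of apply()'s bookkeeping the MTP path still needs. A minimal mock of the resulting call order (illustrative class and names, not the real llama.cpp API):

// Mock of the intended call sequence around the MTP graph, not the real
// classes: set_sinfos() injects the hand-built slot assignment, and
// set_n_kv() stands in for the part of apply() the MTP path still needs.
#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

struct kv_cache_mock {
    uint32_t used_cells = 0;
    uint32_t get_n_kv() const { return used_cells; }
};

class kv_context_mock {
public:
    explicit kv_context_mock(const kv_cache_mock * kv) : kv(kv) {}

    void set_sinfos(std::vector<std::vector<uint32_t>> new_sinfos) {
        sinfos = std::move(new_sinfos);  // overwrite planned slot assignments
    }
    void set_n_kv() { n_kv = kv->get_n_kv(); }  // what apply() would refresh

    uint32_t get_n_kv() const { return n_kv; }

private:
    const kv_cache_mock * kv;
    uint32_t n_kv = 0;
    std::vector<std::vector<uint32_t>> sinfos;
};

int main() {
    kv_cache_mock cache{/*used_cells=*/128};
    kv_context_mock mctx(&cache);

    const uint32_t n_past = 127;
    mctx.set_sinfos({ { n_past } });  // pin the draft token's cell by hand
    mctx.set_n_kv();                  // refresh n_kv before building the graph
    std::printf("graph will see n_kv = %u\n", mctx.get_n_kv());
    return 0;
}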

src/llama-kv-cache-unified.h

Lines changed: 3 additions & 0 deletions

@@ -340,6 +340,7 @@ class llama_kv_cache_unified_context : public llama_memory_context_i {
     //

     uint32_t get_n_kv() const;
+    void set_n_kv();

     // TODO: temporary
     bool get_supports_set_rows() const;
@@ -362,6 +363,8 @@ class llama_kv_cache_unified_context : public llama_memory_context_i {
     void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
     void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;

+    void set_sinfos(slot_info_vec_t new_sinfos);
+
 private:
     llama_memory_status status;

tools/server/server.cpp

Lines changed: 1 addition & 1 deletion

@@ -3545,7 +3545,7 @@ struct server_context {

     llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
     slot.last_tok_idx = tok_idx;
-    SRV_INF("main loop sampled token: '%s'\n", common_token_to_piece(ctx, id, true).c_str());
+    //SRV_INF("main loop sampled token: '%s'\n", common_token_to_piece(ctx, id, true).c_str());

     slot.i_batch = -1;