Skip to content

Commit 5e1d719

Browse files
mtp-batch (feat): Create and manage sinfo for MTP
1 parent febd823 commit 5e1d719

File tree

8 files changed

+232
-65
lines changed

8 files changed

+232
-65
lines changed

common/speculative.cpp

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -418,10 +418,35 @@ void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, b
418418
for (int i = 0; i < mtp_batch.n_tokens; ++i) {
419419
mtp_batch.logits[i] = false;
420420
}
421-
422421
llama_decode(ctx, mtp_batch);
423422
}
424423

424+
void mtp_accept_tokens(
425+
struct llama_context * ctx,
426+
const std::vector<llama_token> & ids,
427+
int32_t n_past_base,
428+
llama_seq_id seq_id
429+
) {
430+
if (ids.empty()) {
431+
return;
432+
}
433+
434+
if (!llama_mtp_prepare_sinfo_for_update(ctx, ids.size())) {
435+
return;
436+
}
437+
438+
llama_batch accepted_batch = llama_batch_init(ids.size(), 0, 1);
439+
for (size_t i = 0; i < ids.size(); ++i) {
440+
common_batch_add(accepted_batch, ids[i], n_past_base + i, { seq_id }, false);
441+
}
442+
443+
mtp_update_kv_cache(ctx, accepted_batch, false);
444+
445+
llama_mtp_cancel_sinfo_update(ctx);
446+
447+
llama_batch_free(accepted_batch);
448+
}
449+
425450
// Debug function - It will be removed later
426451
double calculate_vector_sum_double(const float* vec, size_t size) {
427452
if (!vec) {

common/speculative.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,11 @@ llama_tokens common_speculative_gen_draft(
5151

5252
void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup);
5353

54+
void mtp_accept_tokens(
55+
struct llama_context * ctx,
56+
const std::vector<llama_token> & ids,
57+
int32_t n_past_base,
58+
llama_seq_id seq_id
59+
);
60+
5461
double calculate_vector_sum_double(const float* vec, size_t size);

include/llama.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1457,7 +1457,11 @@ extern "C" {
14571457
ggml_opt_epoch_callback callback_train,
14581458
ggml_opt_epoch_callback callback_eval);
14591459

1460-
LLAMA_API void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state);
1460+
LLAMA_API void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state);
1461+
1462+
LLAMA_API bool llama_mtp_prepare_sinfo_for_update(struct llama_context * ctx, size_t n_accepted);
1463+
1464+
LLAMA_API void llama_mtp_cancel_sinfo_update(struct llama_context * ctx);
14611465

14621466
#ifdef __cplusplus
14631467
}

src/llama-context.cpp

Lines changed: 61 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@
1818
//
1919
// llama_context
2020
//
21+
struct llama_context_kv_cache_data {
22+
llama_kv_cache_unified::slot_info_vec_t last_main_model_sinfos;
23+
llama_kv_cache_unified::slot_info_vec_t resized_sinfo_for_force;
24+
const llama_kv_cache_unified::slot_info_vec_t * forced_sinfos = nullptr;
25+
};
2126

2227
llama_context::llama_context(
2328
const llama_model & model,
@@ -106,6 +111,8 @@ llama_context::llama_context(
106111
cparams.op_offload = params.op_offload;
107112
cparams.kv_unified = params.kv_unified;
108113

114+
kv_cache_data = new llama_context_kv_cache_data();
115+
109116
{
110117
const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
111118
supports_set_rows = LLAMA_SET_ROWS ? (atoi(LLAMA_SET_ROWS) != 0) : supports_set_rows;
@@ -371,6 +378,7 @@ llama_context::llama_context(
371378

372379
llama_context::~llama_context() {
373380
ggml_opt_free(opt_ctx);
381+
delete static_cast<llama_context_kv_cache_data *>(kv_cache_data);
374382
}
375383

376384
void llama_context::synchronize() {
@@ -1017,6 +1025,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
10171025

10181026
int llama_context::decode(const llama_batch & batch_inp) {
10191027
GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
1028+
1029+
auto * kvd = static_cast<llama_context_kv_cache_data *>(kv_cache_data);
10201030
LLAMA_LOG_WARN("[DEBUG-DECODE-ENTRY] Entering llama_decode. update_mtp_kv=%s, use_mtp_head=%s\n",
10211031
batch_inp.update_mtp_kv ? "true" : "false",
10221032
batch_inp.use_mtp_head ? "true" : "false"
@@ -1076,10 +1086,31 @@ int llama_context::decode(const llama_batch & batch_inp) {
10761086
// handle any pending defrags/shifts
10771087
kv_self_update(false);
10781088

1079-
llama_memory_context_ptr mctx;
1089+
std::unique_ptr<llama_memory_context_i> mctx;
10801090

10811091
while (true) {
1082-
mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
1092+
if (cparams.warmup) {
1093+
mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
1094+
} else {
1095+
if (kvd->forced_sinfos && !kvd->forced_sinfos->empty()) {
1096+
LLAMA_LOG_WARN("[DEBUG-CACHE-REUSE] Forcing sinfos, bypassing find_slot.\n");
1097+
1098+
mctx = static_cast<llama_kv_cache_unified *>(memory.get())->init_batch_with_sinfos(
1099+
*balloc, cparams.n_ubatch, *kvd->forced_sinfos, true
1100+
);
1101+
} else {
1102+
mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
1103+
1104+
if (!batch_inp.use_mtp_head && !batch_inp.update_mtp_kv) {
1105+
if (mctx && mctx->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) {
1106+
kvd->last_main_model_sinfos = static_cast<llama_kv_cache_unified_context *>(mctx.get())->get_sinfos();
1107+
} else {
1108+
kvd->last_main_model_sinfos.clear();
1109+
}
1110+
}
1111+
}
1112+
}
1113+
10831114
if (!mctx) {
10841115
return -2;
10851116
}
@@ -1091,29 +1122,28 @@ int llama_context::decode(const llama_batch & batch_inp) {
10911122
case LLAMA_MEMORY_STATUS_NO_UPDATE:
10921123
{
10931124
LLAMA_LOG_ERROR("%s: unexpected memory context status: %d\n", __func__, mctx->get_status());
1094-
10951125
return -2;
10961126
}
10971127
case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
10981128
{
1129+
// if (use_last_main_model_sinfos) {
1130+
// LLAMA_LOG_ERROR("%s: Mismatch between ubatches and sinfos during reuse.\n", __func__);
1131+
// return -1;
1132+
// }
1133+
10991134
if (!did_optimize) {
11001135
did_optimize = true;
1101-
11021136
if (kv_self_update(true)) {
11031137
LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, balloc->get_n_tokens());
1104-
11051138
continue;
11061139
}
11071140
}
1108-
11091141
LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, balloc->get_n_tokens());
1110-
11111142
return 1;
11121143
}
11131144
case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
11141145
{
11151146
LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, balloc->get_n_tokens());
1116-
11171147
return -2;
11181148
}
11191149
}
@@ -3073,4 +3103,27 @@ void llama_opt_epoch(
30733103

30743104
void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state) {
30753105
ctx->draft_input_hidden_state = hidden_state;
3106+
}
3107+
3108+
bool llama_mtp_prepare_sinfo_for_update(struct llama_context * ctx, size_t n_accepted) {
3109+
auto * kvd = static_cast<llama_context_kv_cache_data *>(ctx->kv_cache_data);
3110+
const auto & last_sinfo = kvd->last_main_model_sinfos;
3111+
3112+
if (last_sinfo.empty() || last_sinfo[0].idxs.empty()) {
3113+
LLAMA_LOG_ERROR("%s: The sinfo for the last main call is not available.", __func__);
3114+
return false;
3115+
}
3116+
3117+
kvd->resized_sinfo_for_force = last_sinfo;
3118+
3119+
kvd->resized_sinfo_for_force[0].idxs[0].resize(n_accepted);
3120+
3121+
kvd->forced_sinfos = &kvd->resized_sinfo_for_force;
3122+
3123+
return true;
3124+
}
3125+
3126+
void llama_mtp_cancel_sinfo_update(struct llama_context * ctx) {
3127+
auto * kvd = static_cast<llama_context_kv_cache_data *>(ctx->kv_cache_data);
3128+
kvd->forced_sinfos = nullptr;
30763129
}

src/llama-context.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,20 @@ class llama_io_write_i;
2020
struct llama_memory_i;
2121
struct llama_memory_context_i;
2222

23+
struct llama_context_kv_cache_data;
24+
2325
struct llama_context {
2426
// init scheduler and compute buffers, reserve worst-case graphs
2527
llama_context(
2628
const llama_model & model,
2729
llama_context_params params);
2830

2931
~llama_context();
32+
33+
llama_context(const llama_context &) = delete;
34+
llama_context & operator=(const llama_context &) = delete;
35+
llama_context(llama_context &&) = delete;
36+
llama_context & operator=(llama_context &&) = delete;
3037

3138
void synchronize();
3239

@@ -211,6 +218,9 @@ struct llama_context {
211218

212219
std::unique_ptr<llama_memory_context_i> mtp_memory_batch(const llama_batch& batch_inp);
213220

221+
// For MTP KV cache cell reuse
222+
void * kv_cache_data;
223+
214224
private:
215225
llm_graph_params graph_params(
216226
llm_graph_result * res,

0 commit comments

Comments
 (0)