
Commit 6e9bafc

failed attempt to implement MTP; outputs tokens but KV cache management is unreasonable
1 parent: cf0f7c0

8 files changed: +141, -147 lines


common/sampling.cpp

Lines changed: 5 additions & 0 deletions
@@ -348,6 +348,11 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
 
     llama_sampler_apply(chain, &cur_p);
 
+    /*for (int k = 0; k < (int)cur_p.size; ++k) {
+        LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f)\n",
+            k, 0, cur_p.data[k].id, cur_p.data[k].p);
+    }*/
+
     GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
 
     const llama_token id = cur_p.data[cur_p.selected].id;

common/speculative.cpp

Lines changed: 25 additions & 110 deletions
@@ -6,6 +6,7 @@
 #include "common.h"
 #include "sampling.h"
 #include "../src/llama-graph.h"
+#include "../src/llama-context.h"
 
 #include <cstring>
 #include <algorithm>
@@ -362,126 +363,40 @@ llama_tokens common_speculative_gen_draft(
 }
 
 
-llama_tokens mtp_speculative_gen_draft(
-    struct common_sampler * smpl,
-    struct llama_context * ctx,
-    llama_token id_last,
-    int32_t n_past,
-    int32_t last_tok_idx) {
+llama_token mtp_speculative_gen_draft(
+    struct common_sampler* smpl,
+    struct llama_context* ctx,
+    llama_token id_last,
+    int32_t n_past,
+    int32_t last_tok_idx) {
 
-    llama_tokens result;
-
-    LOG_INF("step: '%d'\n", 1);
-
-    // sample one token from the draft model -- this does NOT generalize to >1 MTP head
-    result.reserve(1);
-
-    // need to determine which architecture we're using so we call the correct MTP model
     const auto * model = llama_get_model(ctx);
-
-    LOG_INF("step: '%d'\n", 2);
-
-    //LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
-    //auto * gf = model.build_graph(gparams);
-
-    LOG_INF("step: '%d'\n", 3);
-
-    /*if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
-        LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
-        ret = GGML_STATUS_ALLOC_FAILED;
-        return nullptr;
-    }*/
-
-    //llm_graph_result res_mtp(ctx->graph_max_nodes());
-    llm_graph_result * res_mtp;
-    llama_ubatch ubatch_mtp;
-    ubatch_mtp.n_tokens = 1;
-    ubatch_mtp.pos = &n_past; // Critical for positional encoding
-
-    // We also need a minimal ubatch to provide positional context (RoPE)
-    // ubatch_mtp.tokens = &last_token_id;
-    // ubatch_mtp.seq_id = llama_get_main_seq_id(ctx); // Assuming a helper
-    // ubatch_mtp.logits = nullptr;
-    // ubatch_mtp.all_pos_0 = -1;
-    // ubatch_mtp.all_pos_1 = -1;
-    // ubatch_mtp.all_seq_id = -1;
-
-    // Manually construct the graph parameters
-    //const llm_graph_params params_mtp = {
-    //    /*.arch        =*/ model->arch,
-    //    /*.hparams     =*/ model->hparams,
-    //    /*.cparams     =*/ ctx->cparams,
-    //    /*.ubatch      =*/ ubatch_mtp,
-    //    /*.gtype       =*/ LLM_GRAPH_TYPE_DECODER,
-    //    /*.sched       =*/ ctx->sched.get(),
-    //    /*.backend_cpu =*/ ctx->backend_cpu,
-    //    /*.cvec        =*/ &ctx->cvec,
-    //    /*.loras       =*/ &ctx->loras,
-    //    /*.mctx        =*/ llama_get_memory(ctx), // Use the KV cache's memory context
-    //    /*.cross       =*/ &ctx->cross,
-    //    /*.n_outputs   =*/ 1,
-    //    /*.cb          =*/ ctx->graph_get_cb(),
-    //    /*.res         =*/ &res_mtp, // Point to our temporary result object
-    //};
-    llm_graph_params params_mtp = llama_mtp_graph_params(ctx, res_mtp, ubatch_mtp);
-
-    LOG_INF("step: '%d'\n", 4);
-
-    // ggml_cgraph* build_mtp_graph(const llm_graph_params & params,
-    //     ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past) const;
     auto * last_embd = llama_get_embeddings_tensor(ctx);
 
-    LOG_INF("step: '%d'\n", 5);
-
     GGML_ASSERT(model != nullptr);
     GGML_ASSERT(last_embd != nullptr);
+    llama_build_and_execute_mtp_graph(ctx, last_embd, id_last, n_past, last_tok_idx);
 
-    auto * gf = llama_build_mtp_graph(model, params_mtp, last_embd, id_last, n_past);
-
-    if (!gf) {
-        LOG_INF("%s: failed to initialize graph\n", __func__);
-        //ret = GGML_STATUS_FAILED;
-        return result;
-    }
-
-    LOG_INF("step: '%d'\n", 6);
-
-    const auto status = llama_graph_compute(ctx, gf, false);
-
-    LOG_INF("step: '%d'\n", 7);
-
-    struct ggml_tensor * logits_mtp = llama_graph_result_get_logits(res_mtp);
-    float * ctx_logit_pointer = llama_get_logits(ctx);
+    common_sampler_sample(smpl, ctx, last_tok_idx, true);
 
-    LOG_INF("step: '%d'\n", 8);
+    const auto* cur_p = common_sampler_get_candidates(smpl);
+    /*LOG_INF("cur_p->size: %d\n", cur_p->size);
 
-    if (logits_mtp) {
-        llama_set_logits(ctx, logits_mtp);
-    }
-
-    LOG_INF("step: '%d'\n", 9);
-
-    {
-        common_sampler_sample(smpl, ctx, last_tok_idx, true);
-
-        LOG_INF("step: '%d'\n", 10);
-
-        const auto * cur_p = common_sampler_get_candidates(smpl);
-
-        for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
-            LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
-                k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
-        }
-
-        // add drafted token for each sequence
-        const llama_token id = cur_p->data[0].id;
+    for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
+        LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
+            k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
+    }*/
 
-        // skip accepting draft token -- since we're only drafting one token this can't affect future outputs
-        // smpl will accept the token if it doesn't get rejected by main model later
-        // common_sampler_accept(smpl, id, true);
+    // add drafted token for each sequence
+    const llama_token id = cur_p->data[0].id;
 
-        result.push_back(id);
-    }
+    // skip accepting draft token -- since we're only drafting one token this can't affect future outputs
+    // smpl will accept the token if it doesn't get rejected by main model later
+    // common_sampler_accept(smpl, id, true);
 
-    return result;
+    //llama_tokens result;
+    //result.reserve(1);
+    //result.push_back(id);
+    //return result;
+    return id;
 }
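
For context, a hypothetical caller-side sketch (not part of this commit) of how the new single-token interface might be consumed: draft one token with the MTP head, then queue it behind the last accepted token so the main model can verify it on the next decode. It assumes the common.h batch helpers and single-sequence bookkeeping; the real server integration and position handling may differ.

```cpp
#include "common.h"
#include "sampling.h"
#include "speculative.h"

// Hypothetical helper: draft a single MTP token and stage it for verification.
static llama_token draft_and_queue(common_sampler * smpl, llama_context * ctx,
                                   llama_batch & batch, llama_token id_last,
                                   int32_t n_past, int32_t last_tok_idx) {
    // one drafted token -- the MTP path in this branch never drafts more
    const llama_token id_draft = mtp_speculative_gen_draft(smpl, ctx, id_last, n_past, last_tok_idx);

    common_batch_clear(batch);
    common_batch_add(batch, id_last,  n_past,     { 0 }, true); // last accepted token
    common_batch_add(batch, id_draft, n_past + 1, { 0 }, true); // speculative token to verify
    return id_draft;
}
```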

common/speculative.h

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ void common_speculative_add_replacement_tgt_dft(
 
 
 // sample up to n_draft tokens and add them to the batch using the draft model
-llama_tokens mtp_speculative_gen_draft(
+llama_token mtp_speculative_gen_draft(
     struct common_sampler* smpl,
     struct llama_context* ctx,
     llama_token id_last,

include/llama.h

Lines changed: 3 additions & 2 deletions
@@ -977,8 +977,6 @@
     // returns NULL for invalid ids.
     LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
 
-    LLAMA_API void llama_set_logits(struct llama_context* ctx, struct ggml_tensor* logit_override);
-
     // Get all output token embeddings.
     // when pooling_type == LLAMA_POOLING_TYPE_NONE or when using a generative model,
     // the embeddings for which llama_batch.logits[i] != 0 are stored contiguously
@@ -1465,6 +1463,9 @@
 
     LLAMA_API ggml_status llama_graph_compute(struct llama_context * ctx, struct ggml_cgraph * gf, bool batched);
 
+    LLAMA_API void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
+        ggml_tensor* hidden_state_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx);
+
     LLAMA_API ggml_tensor * llama_graph_result_get_logits(class llm_graph_result * res);
 
 
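
A minimal sketch of the call order the new declaration implies, mirroring how common/speculative.cpp uses it above; llama_get_embeddings_tensor() is the branch-specific accessor for the last hidden state, and error handling is omitted:

```cpp
#include "llama.h"

// Sketch only: run the MTP head once on the last hidden state, then read the
// logits it wrote back into the context's output buffer for that row.
static float * mtp_forward_once(llama_context * ctx, llama_token id_last,
                                int32_t n_past, int32_t last_tok_idx) {
    ggml_tensor * last_embd = llama_get_embeddings_tensor(ctx); // hidden state feeding the MTP head
    llama_build_and_execute_mtp_graph(ctx, last_embd, id_last, n_past, last_tok_idx);
    return llama_get_logits_ith(ctx, last_tok_idx);             // MTP logits now occupy this output row
}
```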

src/llama-context.cpp

Lines changed: 56 additions & 14 deletions
@@ -523,12 +523,16 @@ float * llama_context::get_logits() {
     return logits;
 }
 
-void llama_context::set_logits(struct ggml_tensor * logit_override) {
-    ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), logit_override);
+void llama_context::set_logits_ith(struct ggml_tensor * logit_override, ggml_backend_sched_t sched_override, int32_t i) {
+    output_reorder();
+
+    ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched_override, logit_override);
     GGML_ASSERT(backend_res != nullptr);
     GGML_ASSERT(logits != nullptr);
 
-    ggml_backend_tensor_get_async(backend_res, logit_override, logits, 0, model.vocab.n_tokens() * sizeof(float));
+    int64_t j = output_ids[i];
+
+    ggml_backend_tensor_get_async(backend_res, logit_override, logits + j*model.vocab.n_tokens(), 0, model.vocab.n_tokens() * sizeof(float));
 }
 
 float * llama_context::get_logits_ith(int32_t i) {
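
The new row offset mirrors what llama_context::get_logits_ith() does on the read side, so the MTP logits land in the slot that a later common_sampler_sample(smpl, ctx, last_tok_idx, ...) will read. A simplified sketch of that mapping (bounds checks and the output_reorder() interaction omitted; assumes output_ids is the usual batch-index-to-output-row table):

```cpp
#include <cstdint>
#include <vector>

// Simplified read-side counterpart of set_logits_ith(): output_ids maps a batch
// position i to its row j in the packed logits buffer, so a write at
// logits + j*n_vocab is exactly what get_logits_ith(i) will hand back later.
static float * logits_row(float * logits, const std::vector<int32_t> & output_ids,
                          int64_t n_vocab, int32_t i) {
    const int64_t j = output_ids[i]; // batch index -> output row
    return logits + j * n_vocab;     // row that set_logits_ith() overwrites
}
```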
@@ -1445,21 +1449,23 @@
 
 llm_graph_params llama_context::mtp_graph_params(
         llm_graph_result* res,
-        const llama_ubatch& ubatch) const {
+        const llama_ubatch& ubatch) {
+    size_t n_nodes = std::max<uint32_t>(1024u, 8u * 8u * (((model.hparams.nextn_predict_layers + 1) * model.n_tensors()) / model.hparams.n_layer));
+    ggml_backend_sched_t temp_sched = create_temp_scheduler(n_nodes);
     return {
         /*.arch        =*/ model.arch,
         /*.hparams     =*/ model.hparams,
         /*.cparams     =*/ cparams,
         /*.ubatch      =*/ ubatch,
         /*.gtype       =*/ LLM_GRAPH_TYPE_DECODER,
-        /*.sched       =*/ sched.get(),
+        /*.sched       =*/ temp_sched,
         /*.backend_cpu =*/ backend_cpu,
         /*.cvec        =*/ &cvec,
         /*.loras       =*/ &loras,
         /*.mctx        =*/ memory->init_batch(*balloc, 1, false).get(),
         /*.cross       =*/ &cross,
         /*.n_outputs   =*/ 1,
-        /*.cb          =*/ graph_get_cb(),
+        /*.cb          =*/ graph_get_cb(temp_sched),
         /*.res         =*/ res,
     };
 }
@@ -1491,8 +1497,10 @@
     return status;
 }
 
-llm_graph_cb llama_context::graph_get_cb() const {
-    return [&](const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il) {
+llm_graph_cb llama_context::graph_get_cb(ggml_backend_sched * sched_override) const {
+    ggml_backend_sched * cb_sched = sched_override ? sched_override : sched.get();
+
+    return [=](const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il) {
         if (il >= 0) {
             ggml_format_name(cur, "%s-%d", name, il);
         } else {
@@ -1502,7 +1510,7 @@ llm_graph_cb llama_context::graph_get_cb() const {
         if (!cparams.offload_kqv) {
             if (strcmp(name, "kqv_merged_cont") == 0) {
                 // all nodes between the KV store and the attention output are run on the CPU
-                ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend_cpu);
+                ggml_backend_sched_set_tensor_backend(cb_sched, cur, backend_cpu);
             }
         }
 
@@ -1515,7 +1523,7 @@
             for (const auto & backend : backends) {
                 if (ggml_backend_get_device(backend.get()) == dev_layer) {
                     if (ggml_backend_supports_op(backend.get(), cur)) {
-                        ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend.get());
+                        ggml_backend_sched_set_tensor_backend(cb_sched, cur, backend.get());
                     }
                 }
             }
@@ -1524,6 +1532,10 @@
     };
 }
 
+ggml_backend_sched_t llama_context::create_temp_scheduler(size_t n_nodes) {
+    return ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), n_nodes, false, cparams.op_offload);
+}
+
 //
 // state save/load
 //
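
The node-count estimate passed to the temporary scheduler above scales the model's total tensor count down to the slice the MTP head touches, with a 1024-node floor. A worked version of that arithmetic (the model figures are illustrative assumptions, not measured values):

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdio>

// Illustrative only: reproduce the sizing formula used for the temporary
// scheduler with assumed numbers (one nextn/MTP layer, 12000 tensors, 61 layers).
int main() {
    const uint32_t nextn_predict_layers = 1;     // assumed MTP depth
    const uint32_t n_tensors            = 12000; // assumed total tensor count
    const uint32_t n_layer              = 61;    // assumed transformer layer count

    const size_t n_nodes = std::max<uint32_t>(
        1024u, 8u * 8u * (((nextn_predict_layers + 1) * n_tensors) / n_layer));

    std::printf("MTP graph budget: %zu nodes\n", n_nodes); // 25152 for these numbers
    return 0;
}
```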
@@ -2450,10 +2462,6 @@ float * llama_get_logits_ith(llama_context * ctx, int32_t i) {
     return ctx->get_logits_ith(i);
 }
 
-void llama_set_logits(llama_context* ctx, struct ggml_tensor* logit_override) {
-    ctx->set_logits(logit_override);
-}
-
 
 float * llama_get_embeddings(llama_context * ctx) {
     ctx->synchronize();
@@ -2985,3 +2993,37 @@ llm_graph_params llama_mtp_graph_params(llama_context* ctx, llm_graph_result* re
 ggml_status llama_graph_compute(llama_context* ctx, ggml_cgraph* gf, bool batched) {
     return ctx->graph_compute(gf, batched);
 }
+
+void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
+        ggml_tensor * hidden_state_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) {
+
+    const auto * model = llama_get_model(ctx);
+
+    auto res_mtp = std::make_unique<llm_graph_result>(ctx->graph_max_nodes());
+
+    llama_ubatch ubatch_mtp;
+    ubatch_mtp.n_tokens = 1;
+    ubatch_mtp.pos = &n_past;
+
+    auto params_mtp = std::make_unique<llm_graph_params>(ctx->mtp_graph_params(res_mtp.get(), ubatch_mtp));
+
+    auto* gf = model->build_mtp_graph(*params_mtp, hidden_state_inp, last_token_id, n_past);
+
+    ggml_backend_sched_t sched = params_mtp->sched;
+
+    ggml_backend_sched_reset(sched); // clear the allocation of the previous graph
+    ggml_backend_sched_alloc_graph(sched, gf); // explicitly allocate the new graph but do not execute it
+
+    ggml_tensor * mtp_token_id_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_token_id_input");
+
+    ggml_backend_tensor_set(mtp_token_id_input, &last_token_id, 0, sizeof(last_token_id)); // copy data to the newly allocated graph tensors
+    ggml_backend_sched_graph_compute(sched, gf); // execute the graph
+
+    struct ggml_tensor * logits_mtp = res_mtp->get_logits();
+    LLAMA_LOG_INFO("logits_mtp pointer address: %p\n", (void*)logits_mtp);
+
+    if (logits_mtp) {
+        ctx->set_logits_ith(logits_mtp, sched, last_tok_idx);
+    }
+}
+
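
The function above walks the standard ggml_backend_sched lifecycle for a one-off graph. A standalone sketch of that lifecycle using the ggml-backend calls this commit relies on; the final ggml_backend_sched_free() is an assumption about where cleanup would go, since the commit keeps the per-call scheduler from create_temp_scheduler() alive:

```cpp
#include "ggml-backend.h"

// Standalone sketch: size a scheduler for one graph, allocate the graph,
// set its inputs, compute, then release the scheduler again.
static enum ggml_status run_graph_once(ggml_backend_t * backends,
                                       ggml_backend_buffer_type_t * bufts,
                                       int n_backends, size_t n_nodes,
                                       struct ggml_cgraph * gf) {
    ggml_backend_sched_t sched = ggml_backend_sched_new(
        backends, bufts, n_backends, n_nodes, /*parallel =*/ false, /*op_offload =*/ true);

    ggml_backend_sched_reset(sched);                  // drop any previous allocation
    if (!ggml_backend_sched_alloc_graph(sched, gf)) { // allocate, but do not run yet
        ggml_backend_sched_free(sched);
        return GGML_STATUS_ALLOC_FAILED;
    }
    // ... ggml_backend_tensor_set() on the graph's input tensors goes here ...
    const enum ggml_status status = ggml_backend_sched_graph_compute(sched, gf);

    ggml_backend_sched_free(sched);                   // release graph/compute buffers
    return status;
}
```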

src/llama-context.h

Lines changed: 5 additions & 3 deletions
@@ -200,9 +200,11 @@ struct llama_context {
     // reserve a graph with a dummy ubatch of the specified size
     ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
 
-    llm_graph_params mtp_graph_params(llm_graph_result * res, const llama_ubatch & ubatch) const;
+    llm_graph_params mtp_graph_params(llm_graph_result * res, const llama_ubatch & ubatch);
 
-    void set_logits(struct ggml_tensor* logit_override);
+    void set_logits_ith(struct ggml_tensor * logit_override, ggml_backend_sched_t sched_override, int32_t i);
+
+    ggml_backend_sched_t create_temp_scheduler(size_t n_nodes);
 
 private:
     llm_graph_params graph_params(
@@ -211,7 +213,7 @@
             const llama_memory_context_i * mctx,
             llm_graph_type gtype) const;
 
-    llm_graph_cb graph_get_cb() const;
+    llm_graph_cb graph_get_cb(ggml_backend_sched * sched_override = nullptr) const;
 
     // TODO: read/write lora adapters and cvec
     size_t state_write_data(llama_io_write_i & io);
