
Commit cf0f7c0

broad thrust of the mtp implementation
1 parent 03231da commit cf0f7c0

9 files changed: +260 −11 lines changed

common/speculative.cpp

Lines changed: 126 additions & 0 deletions
@@ -5,6 +5,7 @@
#include "log.h"
#include "common.h"
#include "sampling.h"
+#include "../src/llama-graph.h"

#include <cstring>
#include <algorithm>

@@ -359,3 +360,128 @@ llama_tokens common_speculative_gen_draft(
    }
    return result;
}
+
+
+llama_tokens mtp_speculative_gen_draft(
+        struct common_sampler * smpl,
+        struct llama_context * ctx,
+        llama_token id_last,
+        int32_t n_past,
+        int32_t last_tok_idx) {
+
+    llama_tokens result;
+
+    LOG_INF("step: '%d'\n", 1);
+
+    // sample one token from the draft model -- this does NOT generalize to >1 MTP head
+    result.reserve(1);
+
+    // need to determine which architecture we're using so we call the correct MTP model
+    const auto * model = llama_get_model(ctx);
+
+    LOG_INF("step: '%d'\n", 2);
+
+    //LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
+    //auto * gf = model.build_graph(gparams);
+
+    LOG_INF("step: '%d'\n", 3);
+
+    /*if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
+        LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
+        ret = GGML_STATUS_ALLOC_FAILED;
+        return nullptr;
+    }*/
+
+    //llm_graph_result res_mtp(ctx->graph_max_nodes());
+    llm_graph_result * res_mtp;
+    llama_ubatch ubatch_mtp;
+    ubatch_mtp.n_tokens = 1;
+    ubatch_mtp.pos = &n_past; // Critical for positional encoding
+
+    // We also need a minimal ubatch to provide positional context (RoPE)
+    // ubatch_mtp.tokens = &last_token_id;
+    // ubatch_mtp.seq_id = llama_get_main_seq_id(ctx); // Assuming a helper
+    // ubatch_mtp.logits = nullptr;
+    // ubatch_mtp.all_pos_0 = -1;
+    // ubatch_mtp.all_pos_1 = -1;
+    // ubatch_mtp.all_seq_id = -1;
+
+    // Manually construct the graph parameters
+    //const llm_graph_params params_mtp = {
+    //    /*.arch        =*/ model->arch,
+    //    /*.hparams     =*/ model->hparams,
+    //    /*.cparams     =*/ ctx->cparams,
+    //    /*.ubatch      =*/ ubatch_mtp,
+    //    /*.gtype       =*/ LLM_GRAPH_TYPE_DECODER,
+    //    /*.sched       =*/ ctx->sched.get(),
+    //    /*.backend_cpu =*/ ctx->backend_cpu,
+    //    /*.cvec        =*/ &ctx->cvec,
+    //    /*.loras       =*/ &ctx->loras,
+    //    /*.mctx        =*/ llama_get_memory(ctx), // Use the KV cache's memory context
+    //    /*.cross       =*/ &ctx->cross,
+    //    /*.n_outputs   =*/ 1,
+    //    /*.cb          =*/ ctx->graph_get_cb(),
+    //    /*.res         =*/ &res_mtp, // Point to our temporary result object
+    //};
+    llm_graph_params params_mtp = llama_mtp_graph_params(ctx, res_mtp, ubatch_mtp);
+
+    LOG_INF("step: '%d'\n", 4);
+
+    // ggml_cgraph* build_mtp_graph(const llm_graph_params & params,
+    //     ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past) const;
+    auto * last_embd = llama_get_embeddings_tensor(ctx);
+
+    LOG_INF("step: '%d'\n", 5);
+
+    GGML_ASSERT(model != nullptr);
+    GGML_ASSERT(last_embd != nullptr);
+
+    auto * gf = llama_build_mtp_graph(model, params_mtp, last_embd, id_last, n_past);
+
+    if (!gf) {
+        LOG_INF("%s: failed to initialize graph\n", __func__);
+        //ret = GGML_STATUS_FAILED;
+        return result;
+    }
+
+    LOG_INF("step: '%d'\n", 6);
+
+    const auto status = llama_graph_compute(ctx, gf, false);
+
+    LOG_INF("step: '%d'\n", 7);
+
+    struct ggml_tensor * logits_mtp = llama_graph_result_get_logits(res_mtp);
+    float * ctx_logit_pointer = llama_get_logits(ctx);
+
+    LOG_INF("step: '%d'\n", 8);
+
+    if (logits_mtp) {
+        llama_set_logits(ctx, logits_mtp);
+    }
+
+    LOG_INF("step: '%d'\n", 9);
+
+    {
+        common_sampler_sample(smpl, ctx, last_tok_idx, true);
+
+        LOG_INF("step: '%d'\n", 10);
+
+        const auto * cur_p = common_sampler_get_candidates(smpl);
+
+        for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
+            LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
+                k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
+        }
+
+        // add drafted token for each sequence
+        const llama_token id = cur_p->data[0].id;
+
+        // skip accepting draft token -- since we're only drafting one token this can't affect future outputs
+        // smpl will accept the token if it doesn't get rejected by main model later
+        // common_sampler_accept(smpl, id, true);
+
+        result.push_back(id);
+    }
+
+    return result;
+}
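Note: `mtp_speculative_gen_draft` returns at most one token, since `result.reserve(1)` and the single MTP head cap the draft length. Below is a minimal caller-side sketch of how a speculative decode loop might wrap the new helper; the surrounding names (`ctx_tgt`, the sampler handle, the position bookkeeping) are assumptions about that loop, not part of this commit.

    // Sketch only -- the wiring around the new helper is hypothetical.
    #include "speculative.h"
    #include "sampling.h"
    #include "llama.h"

    static llama_token draft_one_mtp_token(
            struct common_sampler * smpl,      // sampler already attached to the target model
            struct llama_context  * ctx_tgt,   // context that just decoded `id_last`
            llama_token             id_last,   // last token accepted by the target model
            int32_t                 n_past,    // number of tokens already in the KV cache
            int32_t                 last_tok_idx) {
        // run the MTP head once; the result holds zero or one draft token
        const llama_tokens draft = mtp_speculative_gen_draft(smpl, ctx_tgt, id_last, n_past, last_tok_idx);

        return draft.empty() ? LLAMA_TOKEN_NULL : draft[0];
    }

Because only one position is ever drafted, the caller verifies a single speculative token per step and falls back to normal decoding if it is rejected.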

common/speculative.h

Lines changed: 9 additions & 0 deletions
@@ -27,6 +27,15 @@ void common_speculative_add_replacement_tgt_dft(
        struct common_speculative * spec,
        const char * source, const char * dest);

+
+// sample up to n_draft tokens and add them to the batch using the draft model
+llama_tokens mtp_speculative_gen_draft(
+        struct common_sampler * smpl,
+        struct llama_context * ctx,
+        llama_token id_last,
+        int32_t n_past,
+        int32_t last_tok_idx);
+
// sample up to n_draft tokens and add them to the batch using the draft model
llama_tokens common_speculative_gen_draft(
        struct common_speculative * spec,

include/llama.h

Lines changed: 17 additions & 0 deletions
@@ -544,12 +544,17 @@
    // Returns true if the model is diffusion-based (like LLaDA, Dream, etc.)
    LLAMA_API bool llama_model_is_diffusion(const struct llama_model * model);

+    LLAMA_API ggml_cgraph * llama_build_mtp_graph(const struct llama_model * model, const struct llm_graph_params & params,
+            struct ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past);
+
    // Returns 0 on success
    LLAMA_API uint32_t llama_model_quantize(
            const char * fname_inp,
            const char * fname_out,
            const llama_model_quantize_params * params);

+
+
    //
    // Adapters
    //

@@ -972,6 +977,8 @@
    // returns NULL for invalid ids.
    LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);

+    LLAMA_API void llama_set_logits(struct llama_context * ctx, struct ggml_tensor * logit_override);
+
    // Get all output token embeddings.
    // when pooling_type == LLAMA_POOLING_TYPE_NONE or when using a generative model,
    // the embeddings for which llama_batch.logits[i] != 0 are stored contiguously

@@ -994,6 +1001,8 @@
    // otherwise: float[n_embd] (1-dimensional)
    LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);

+    LLAMA_API ggml_tensor * llama_get_embeddings_tensor(struct llama_context * ctx);
+
    //
    // Vocab
    //

@@ -1452,6 +1461,14 @@
            ggml_opt_epoch_callback callback_train,
            ggml_opt_epoch_callback callback_eval);

+    LLAMA_API llm_graph_params llama_mtp_graph_params(struct llama_context * ctx, class llm_graph_result * res, const struct llama_ubatch & ubatch);
+
+    LLAMA_API ggml_status llama_graph_compute(struct llama_context * ctx, struct ggml_cgraph * gf, bool batched);
+
+    LLAMA_API ggml_tensor * llama_graph_result_get_logits(class llm_graph_result * res);
+
+
+
#ifdef __cplusplus
}
#endif
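Taken together, the new exports are meant to be called in the order that `mtp_speculative_gen_draft` uses them. A condensed sketch of that sequence follows; how the `llm_graph_result` is allocated is still open in this commit (the .cpp above leaves `res_mtp` uninitialized), so `res` here is a placeholder, and `ctx`, `model`, `id_last`, and `n_past` are assumed to come from the caller.

    // Hypothetical call sequence through the new public API (not yet exercised end-to-end).
    llm_graph_result * res = /* result object owned by the caller or the context -- unresolved here */ nullptr;

    llama_ubatch ubatch = {};
    ubatch.n_tokens = 1;
    ubatch.pos      = &n_past;                                     // position of the token being drafted

    const llm_graph_params params = llama_mtp_graph_params(ctx, res, ubatch);

    ggml_tensor * last_embd = llama_get_embeddings_tensor(ctx);   // hidden state kept from the last decode
    ggml_cgraph * gf        = llama_build_mtp_graph(model, params, last_embd, id_last, n_past);

    if (gf != nullptr && llama_graph_compute(ctx, gf, /*batched=*/false) == GGML_STATUS_SUCCESS) {
        ggml_tensor * logits_mtp = llama_graph_result_get_logits(res);
        if (logits_mtp != nullptr) {
            // overwrite the context's logits so a sampler sees the MTP head's distribution
            llama_set_logits(ctx, logits_mtp);
        }
    }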

src/llama-context.cpp

Lines changed: 59 additions & 0 deletions
@@ -6,6 +6,7 @@
#include "llama-memory.h"
#include "llama-mmap.h"
#include "llama-model.h"
+#include "llama-graph.h"

#include <cinttypes>
#include <cstring>

@@ -522,6 +523,14 @@ float * llama_context::get_logits() {
    return logits;
}

+void llama_context::set_logits(struct ggml_tensor * logit_override) {
+    ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), logit_override);
+    GGML_ASSERT(backend_res != nullptr);
+    GGML_ASSERT(logits != nullptr);
+
+    ggml_backend_tensor_get_async(backend_res, logit_override, logits, 0, model.vocab.n_tokens() * sizeof(float));
+}
+
float * llama_context::get_logits_ith(int32_t i) {
    int64_t j = -1;

@@ -617,6 +626,10 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
    return it->second.data();
}

+ggml_tensor * llama_context::get_embeddings_tensor() {
+    return embd_tensor;
+}
+
void llama_context::attach_threadpool(
        ggml_threadpool_t threadpool,
        ggml_threadpool_t threadpool_batch) {

@@ -1113,6 +1126,7 @@ int llama_context::decode(const llama_batch & batch_inp) {

    auto * t_logits = res->get_logits();
    auto * t_embd   = cparams.embeddings ? res->get_embd() : nullptr;
+    embd_tensor = res->get_embd();

    if (t_embd && res->get_embd_pooled()) {
        t_embd = res->get_embd_pooled();

@@ -1429,6 +1443,27 @@ llm_graph_params llama_context::graph_params(
    };
}

+llm_graph_params llama_context::mtp_graph_params(
+        llm_graph_result * res,
+        const llama_ubatch & ubatch) const {
+    return {
+        /*.arch        =*/ model.arch,
+        /*.hparams     =*/ model.hparams,
+        /*.cparams     =*/ cparams,
+        /*.ubatch      =*/ ubatch,
+        /*.gtype       =*/ LLM_GRAPH_TYPE_DECODER,
+        /*.sched       =*/ sched.get(),
+        /*.backend_cpu =*/ backend_cpu,
+        /*.cvec        =*/ &cvec,
+        /*.loras       =*/ &loras,
+        /*.mctx        =*/ memory->init_batch(*balloc, 1, false).get(),
+        /*.cross       =*/ &cross,
+        /*.n_outputs   =*/ 1,
+        /*.cb          =*/ graph_get_cb(),
+        /*.res         =*/ res,
+    };
+}
+
ggml_status llama_context::graph_compute(
        ggml_cgraph * gf,
        bool batched) {

@@ -2233,6 +2268,7 @@ void llama_context::opt_epoch(
    llama_batch_free(batch);
}

+
//
// interface implementation
//

@@ -2274,6 +2310,8 @@ llama_context_params llama_context_default_params() {
    return result;
}

+
+
llama_context * llama_init_from_model(
        llama_model * model,
        llama_context_params params) {

@@ -2412,6 +2450,11 @@ float * llama_get_logits_ith(llama_context * ctx, int32_t i) {
    return ctx->get_logits_ith(i);
}

+void llama_set_logits(llama_context * ctx, struct ggml_tensor * logit_override) {
+    ctx->set_logits(logit_override);
+}
+
+
float * llama_get_embeddings(llama_context * ctx) {
    ctx->synchronize();

@@ -2430,6 +2473,13 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
    return ctx->get_embeddings_seq(seq_id);
}

+ggml_tensor * llama_get_embeddings_tensor(llama_context * ctx) {
+    ctx->synchronize();
+
+    return ctx->get_embeddings_tensor();
+}
+
+
// llama adapter API

int32_t llama_set_adapter_lora(

@@ -2926,3 +2976,12 @@ void llama_opt_epoch(
            callback_train,
            callback_eval);
}
+
+llm_graph_params llama_mtp_graph_params(llama_context * ctx, llm_graph_result * res, const llama_ubatch & ubatch) {
+    return ctx->mtp_graph_params(res, ubatch);
+}
+
+
+ggml_status llama_graph_compute(llama_context * ctx, ggml_cgraph * gf, bool batched) {
+    return ctx->graph_compute(gf, batched);
+}
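The new `llama_context::set_logits` copies exactly `model.vocab.n_tokens()` floats from the override tensor into the start of the context's logits buffer, and it does so asynchronously via `ggml_backend_tensor_get_async`. A short sketch of what a reader of the buffer would do afterwards; the explicit `llama_synchronize` call is an assumption added here, since the copy itself does not block.

    // Assumed read-back path after the override (greedy argmax only to illustrate).
    llama_set_logits(ctx, logits_mtp);
    llama_synchronize(ctx);                                    // wait for the async copy to land

    const llama_vocab * vocab   = llama_model_get_vocab(llama_get_model(ctx));
    const int           n_vocab = llama_vocab_n_tokens(vocab);

    const float * logits = llama_get_logits(ctx);              // first n_vocab floats now hold the MTP logits
    int best = 0;
    for (int i = 1; i < n_vocab; ++i) {
        if (logits[i] > logits[best]) {
            best = i;
        }
    }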

src/llama-context.h

Lines changed: 7 additions & 0 deletions
@@ -59,6 +59,7 @@ struct llama_context {
    float * get_embeddings();
    float * get_embeddings_ith(int32_t i);
    float * get_embeddings_seq(llama_seq_id seq_id);
+    ggml_tensor * get_embeddings_tensor();

    void attach_threadpool(
            ggml_threadpool_t threadpool,

@@ -199,6 +200,10 @@ struct llama_context {
    // reserve a graph with a dummy ubatch of the specified size
    ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);

+    llm_graph_params mtp_graph_params(llm_graph_result * res, const llama_ubatch & ubatch) const;
+
+    void set_logits(struct ggml_tensor * logit_override);
+
private:
    llm_graph_params graph_params(
            llm_graph_result * res,

@@ -240,6 +245,7 @@ struct llama_context {
    // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
    size_t  embd_size = 0; // capacity (of floats) for embeddings
    float * embd      = nullptr;
+    ggml_tensor * embd_tensor = nullptr;

    // sequence embeddings output (map of [n_embd] vectors)
    // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE

@@ -308,3 +314,4 @@ struct llama_context {

    mutable int32_t n_reused = 0; // number of times the previous graph was reused
};
+

src/llama-graph.cpp

Lines changed: 4 additions & 0 deletions
@@ -1911,3 +1911,7 @@ int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {

    return relative_bucket;
}
+
+ggml_tensor * llama_graph_result_get_logits(llm_graph_result * res) {
+    return res->get_logits();
+}

src/llama-graph.h

Lines changed: 1 addition & 0 deletions
@@ -818,3 +818,4 @@ struct llm_graph_context {

// TODO: better name
int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional);
+
