5 | 5 | #include "log.h" |
6 | 6 | #include "common.h" |
7 | 7 | #include "sampling.h" |
| 8 | +#include "../src/llama-graph.h" |
8 | 9 |
9 | 10 | #include <cstring> |
10 | 11 | #include <algorithm> |
@@ -359,3 +360,76 @@ llama_tokens common_speculative_gen_draft( |
359 | 360 | } |
360 | 361 | return result; |
361 | 362 | } |
| 363 | + |
| 364 | + |
| 365 | +llama_tokens mtp_speculative_gen_draft( |
| 366 | +        struct common_sampler * smpl, |
| 367 | +        struct llama_context * ctx, |
| 368 | +        llama_token id_last, |
| 369 | +        int32_t n_past, |
| 370 | +        int32_t last_tok_idx) { |
| 371 | + |
| 372 | +    llama_tokens result; |
| 373 | + |
| 374 | +    // sample one token from the MTP head -- this does NOT generalize to >1 MTP head |
| 375 | +    result.reserve(1); |
| 376 | + |
| 377 | +    // the model is needed to build the architecture-specific MTP graph |
| 378 | +    const auto * model = llama_get_model(ctx); |
| 379 | + |
| 380 | +    // a minimal ubatch provides the positional context (RoPE) for the drafted token |
| 381 | +    llama_ubatch ubatch_mtp = {}; |
| 382 | +    ubatch_mtp.n_tokens = 1; |
| 383 | +    ubatch_mtp.pos      = &n_past; |
| 384 | + |
| 385 | +    // the graph parameters (scheduler, backends, cparams, ...) are assembled by llama_mtp_graph_params; |
| 386 | +    // res_mtp is expected to be set up there as well and is later queried for the MTP logits |
| 387 | +    llm_graph_result * res_mtp = nullptr; |
| 388 | +    llm_graph_params params_mtp = llama_mtp_graph_params(ctx, res_mtp, ubatch_mtp); |
| 389 | + |
| 390 | +    // the MTP head conditions on the embeddings produced by the last main-model evaluation |
| 391 | +    auto * last_embd = llama_get_embeddings_tensor(ctx); |
| 392 | + |
| 393 | +    GGML_ASSERT(model != nullptr); |
| 394 | +    GGML_ASSERT(last_embd != nullptr); |
| 395 | + |
| 396 | +    auto * gf = llama_build_mtp_graph(model, params_mtp, last_embd, id_last, n_past); |
| 397 | +    if (!gf) { |
| 398 | +        LOG_ERR("%s: failed to build MTP graph\n", __func__); |
| 399 | +        return result; |
| 400 | +    } |
| 401 | + |
| 402 | +    const auto status = llama_graph_compute(ctx, gf, false); |
| 403 | +    if (status != GGML_STATUS_SUCCESS) { |
| 404 | +        LOG_ERR("%s: MTP graph computation failed with status %d\n", __func__, (int) status); |
| 405 | +        return result; |
| 406 | +    } |
| 407 | + |
| 408 | +    // expose the MTP logits through the context so the common sampler can read them |
| 409 | +    struct ggml_tensor * logits_mtp = llama_graph_result_get_logits(res_mtp); |
| 410 | +    if (logits_mtp) { |
| 411 | +        llama_set_logits(ctx, logits_mtp); |
| 412 | +    } |
| 413 | + |
| 414 | +    { |
| 415 | +        common_sampler_sample(smpl, ctx, last_tok_idx, true); |
| 416 | + |
| 417 | +        const auto * cur_p = common_sampler_get_candidates(smpl); |
| 418 | + |
| 419 | +        for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { |
| 420 | +            LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", |
| 421 | +                k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str()); |
| 422 | +        } |
| 423 | + |
| 424 | +        // add the drafted token for each sequence |
| 425 | +        const llama_token id = cur_p->data[0].id; |
| 426 | + |
| 427 | +        // skip accepting the draft token here -- since only one token is drafted, this cannot |
| 428 | +        // affect future outputs; smpl accepts it only if the main model does not reject it later |
| 429 | +        // common_sampler_accept(smpl, id, true); |
| 430 | + |
| 431 | +        result.push_back(id); |
| 432 | +    } |
| 433 | + |
| 434 | +    return result; |
| 435 | +} |
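
For reference, a minimal sketch of how a caller might consume the new helper. It is not part of this patch: the helper name `mtp_draft_and_queue`, the single sequence id `0`, and the batch layout are assumptions about the surrounding speculative-decoding loop, mirroring how the result of `common_speculative_gen_draft` is fed back to the main model (the last accepted token followed by the draft, both requesting logits so the draft can be verified on the next decode).

```cpp
// Hypothetical caller (not part of this patch): draft one MTP token for id_last and
// queue it together with id_last so the main model can validate it on the next decode.
static bool mtp_draft_and_queue(struct common_sampler * smpl, struct llama_context * ctx,
                                llama_token id_last, int32_t n_past, int32_t last_tok_idx) {
    const llama_tokens draft = mtp_speculative_gen_draft(smpl, ctx, id_last, n_past, last_tok_idx);
    if (draft.empty()) {
        return false;
    }

    // assumed layout: the accepted token at n_past, the drafted token at n_past + 1,
    // both on sequence 0 and both requesting logits for the verification step
    llama_batch batch = llama_batch_init(2, 0, 1);
    common_batch_add(batch, id_last,  n_past,     { 0 }, true);
    common_batch_add(batch, draft[0], n_past + 1, { 0 }, true);

    const bool ok = llama_decode(ctx, batch) == 0;
    llama_batch_free(batch);
    return ok;
}
```

Because only a single token is drafted, acceptance reduces to checking whether the main model samples that same token at position `n_past + 1`, which is why `common_sampler_accept` is not called inside the drafting function itself.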