 #include "common.h"
 #include "sampling.h"
 #include "../src/llama-graph.h"
+#include "../src/llama-context.h"

 #include <cstring>
 #include <algorithm>
@@ -362,126 +363,40 @@ llama_tokens common_speculative_gen_draft( |
 }


-llama_tokens mtp_speculative_gen_draft(
-        struct common_sampler * smpl,
-        struct llama_context * ctx,
-        llama_token id_last,
-        int32_t n_past,
-        int32_t last_tok_idx) {
+llama_token mtp_speculative_gen_draft(
+    struct common_sampler* smpl,
+    struct llama_context* ctx,
+    llama_token id_last,
+    int32_t n_past,
+    int32_t last_tok_idx) {

-    llama_tokens result;
-
-    LOG_INF("step: '%d'\n", 1);
-
-    // sample one token from the draft model -- this does NOT generalize to >1 MTP head
-    result.reserve(1);
-
-    // need to determine which architecture we're using so we call the correct MTP model
     const auto * model = llama_get_model(ctx);
-
-    LOG_INF("step: '%d'\n", 2);
-
-    //LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
-    //auto * gf = model.build_graph(gparams);
-
-    LOG_INF("step: '%d'\n", 3);
-
-    /*if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
-        LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
-        ret = GGML_STATUS_ALLOC_FAILED;
-        return nullptr;
-    }*/
-
-    //llm_graph_result res_mtp(ctx->graph_max_nodes());
-    llm_graph_result * res_mtp;
-    llama_ubatch ubatch_mtp;
-    ubatch_mtp.n_tokens = 1;
-    ubatch_mtp.pos = &n_past; // Critical for positional encoding
-
-    // We also need a minimal ubatch to provide positional context (RoPE)
-    // ubatch_mtp.tokens = &last_token_id;
-    // ubatch_mtp.seq_id = llama_get_main_seq_id(ctx); // Assuming a helper
-    // ubatch_mtp.logits = nullptr;
-    // ubatch_mtp.all_pos_0 = -1;
-    // ubatch_mtp.all_pos_1 = -1;
-    // ubatch_mtp.all_seq_id = -1;
-
-    // Manually construct the graph parameters
-    //const llm_graph_params params_mtp = {
-    //    /*.arch        =*/ model->arch,
-    //    /*.hparams     =*/ model->hparams,
-    //    /*.cparams     =*/ ctx->cparams,
-    //    /*.ubatch      =*/ ubatch_mtp,
-    //    /*.gtype       =*/ LLM_GRAPH_TYPE_DECODER,
-    //    /*.sched       =*/ ctx->sched.get(),
-    //    /*.backend_cpu =*/ ctx->backend_cpu,
-    //    /*.cvec        =*/ &ctx->cvec,
-    //    /*.loras       =*/ &ctx->loras,
-    //    /*.mctx        =*/ llama_get_memory(ctx), // Use the KV cache's memory context
-    //    /*.cross       =*/ &ctx->cross,
-    //    /*.n_outputs   =*/ 1,
-    //    /*.cb          =*/ ctx->graph_get_cb(),
-    //    /*.res         =*/ &res_mtp, // Point to our temporary result object
-    //};
-    llm_graph_params params_mtp = llama_mtp_graph_params(ctx, res_mtp, ubatch_mtp);
-
-    LOG_INF("step: '%d'\n", 4);
-
-    // ggml_cgraph* build_mtp_graph(const llm_graph_params & params,
-    //     ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past) const;
     auto * last_embd = llama_get_embeddings_tensor(ctx);

-    LOG_INF("step: '%d'\n", 5);
-
     GGML_ASSERT(model != nullptr);
     GGML_ASSERT(last_embd != nullptr);
+    llama_build_and_execute_mtp_graph(ctx, last_embd, id_last, n_past, last_tok_idx);

-    auto * gf = llama_build_mtp_graph(model, params_mtp, last_embd, id_last, n_past);
-
-    if (!gf) {
-        LOG_INF("%s: failed to initialize graph\n", __func__);
-        //ret = GGML_STATUS_FAILED;
-        return result;
-    }
-
-    LOG_INF("step: '%d'\n", 6);
-
-    const auto status = llama_graph_compute(ctx, gf, false);
-
-    LOG_INF("step: '%d'\n", 7);
-
-    struct ggml_tensor * logits_mtp = llama_graph_result_get_logits(res_mtp);
-    float * ctx_logit_pointer = llama_get_logits(ctx);
+    common_sampler_sample(smpl, ctx, last_tok_idx, true);

-    LOG_INF("step: '%d'\n", 8);
+    const auto* cur_p = common_sampler_get_candidates(smpl);
+    /*LOG_INF("cur_p->size: %d\n", cur_p->size);

-    if (logits_mtp) {
-        llama_set_logits(ctx, logits_mtp);
-    }
-
-    LOG_INF("step: '%d'\n", 9);
-
-    {
-        common_sampler_sample(smpl, ctx, last_tok_idx, true);
-
-        LOG_INF("step: '%d'\n", 10);
-
-        const auto * cur_p = common_sampler_get_candidates(smpl);
-
-        for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
-            LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
-                k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
-        }
-
-        // add drafted token for each sequence
-        const llama_token id = cur_p->data[0].id;
+    for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
+        LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
+            k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
+    }*/

-        // skip accepting draft token -- since we're only drafting one token this can't affect future outputs
-        // smpl will accept the token if it doesn't get rejected by main model later
-        // common_sampler_accept(smpl, id, true);
+    // add drafted token for each sequence
+    const llama_token id = cur_p->data[0].id;

-        result.push_back(id);
-    }
+    // skip accepting draft token -- since we're only drafting one token this can't affect future outputs
+    // smpl will accept the token if it doesn't get rejected by main model later
+    // common_sampler_accept(smpl, id, true);

-    return result;
+    //llama_tokens result;
+    //result.reserve(1);
+    //result.push_back(id);
+    //return result;
+    return id;
 }
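Editor's note on the refactor: the inline graph plumbing deleted above is now hidden behind a single llama_build_and_execute_mtp_graph(ctx, last_embd, id_last, n_past, last_tok_idx) call, and the function returns a bare llama_token instead of a one-element llama_tokens vector, consistent with the comment that a single MTP head drafts exactly one token per call. Below is a minimal sketch of what that helper plausibly does, reconstructed only from the calls visible in the removed lines; the result-object setup, zero-initialization, and error handling are assumptions, last_tok_idx is omitted for brevity, and the actual implementation in src/llama-context.h may differ.

    // Sketch only -- not the PR's actual implementation. It regroups the steps
    // the deleted code performed inline, using the names visible in the diff.
    static void mtp_graph_sketch(
            struct llama_context * ctx,
            ggml_tensor * last_embd,   // hidden state of the last accepted token
            llama_token id_last,
            int32_t n_past) {
        const auto * model = llama_get_model(ctx);

        // one-token ubatch so the MTP head sees the correct position (RoPE)
        llama_ubatch ubatch_mtp = {};
        ubatch_mtp.n_tokens = 1;
        ubatch_mtp.pos = &n_past;

        // the commented-out code sized this with ctx->graph_max_nodes();
        // the capacity here is an arbitrary placeholder. Note the deleted code
        // passed an uninitialized llm_graph_result pointer; a concrete object
        // is used here instead.
        llm_graph_result res_mtp(/*max_nodes=*/8192);
        llm_graph_params params_mtp = llama_mtp_graph_params(ctx, &res_mtp, ubatch_mtp);

        // build and run the single-token MTP graph
        auto * gf = llama_build_mtp_graph(model, params_mtp, last_embd, id_last, n_past);
        GGML_ASSERT(gf != nullptr);
        llama_graph_compute(ctx, gf, /*batched=*/false);

        // expose the MTP logits through the regular llama_get_logits() path so
        // the caller can sample with common_sampler_sample(smpl, ctx, idx, true)
        if (auto * logits_mtp = llama_graph_result_get_logits(&res_mtp)) {
            llama_set_logits(ctx, logits_mtp);
        }
    }

One design consequence worth noting: because the helper publishes the MTP logits through the context's normal logits buffer, the caller can reuse the existing common_sampler_sample path unchanged, and the drafted token is deliberately not accepted into the sampler state, since a one-token draft cannot influence any later draft in the same call.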