Commit 98bc0c6

replace standard sampler with greedy sampler for mtp draft

1 parent: 471e026

File tree: 4 files changed, +29 -7 lines

  common/speculative.cpp
  include/llama.h
  src/llama-context.cpp
  src/llama-model.cpp

common/speculative.cpp (3 additions, 1 deletion)

@@ -387,9 +387,10 @@ llama_token mtp_speculative_gen_draft(
         /*.logits = */ logits_data
     };

-    llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);
+    return llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);
     //LOG_INF("updating kv cache for n_past: %d\n", n_past);

+    /*
     if (!smpl) {
         return -1;
     }
@@ -405,6 +406,7 @@ llama_token mtp_speculative_gen_draft(
     const llama_token id = cur_p->data[0].id;
     return id;
 }
+    */
     // LOG_INF("cur_p->size: %d\n", cur_p->size);
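Note: the functional change here is that mtp_speculative_gen_draft no longer samples the draft token on the CPU; the newly commented-out block is the old sampler path. A minimal sketch of the two paths, with simplified signatures (the helper names draft_via_sampler/draft_via_graph are illustrative, not from this commit):

    // Old path (now commented out): sample on the CPU via the common sampler.
    llama_token draft_via_sampler(struct common_sampler * smpl, struct llama_context * ctx, int idx) {
        common_sampler_sample(smpl, ctx, idx, /*grammar_first=*/true);
        const llama_token_data_array * cur_p = common_sampler_get_candidates(smpl);
        return cur_p->data[0].id; // top candidate
    }

    // New path: the MTP graph ends in an argmax, so the call itself yields the token.
    llama_token draft_via_graph(struct llama_context * ctx, const llama_batch & batch,
                                llama_token id_last, int32_t n_past, int32_t last_tok_idx) {
        return llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);
    }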

include/llama.h (1 addition, 1 deletion)

@@ -1454,7 +1454,7 @@ extern "C" {
             ggml_opt_epoch_callback callback_train,
             ggml_opt_epoch_callback callback_eval);

-    LLAMA_API void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
+    LLAMA_API llama_token llama_build_and_execute_mtp_graph(struct llama_context * ctx,
         const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx);

 #ifdef __cplusplus
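Note: with the return type changed from void to llama_token, callers consume the greedy draft directly. A hypothetical call site (setup of the arguments elided, as in mtp_speculative_gen_draft above):

    // draft_id is the argmax token produced by the MTP head, ready for verification
    const llama_token draft_id = llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);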

src/llama-context.cpp (21 additions, 5 deletions)

@@ -2995,7 +2995,7 @@ void llama_opt_epoch(
             callback_eval);
 }

-void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
+llama_token llama_build_and_execute_mtp_graph(struct llama_context * ctx,
         const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) {

     const auto * model = llama_get_model(ctx);
@@ -3044,13 +3044,29 @@ void llama_build_and_execute_mtp_graph(struct llama_context * ctx,

     ggml_backend_sched_graph_compute(sched, gf); // execute the graph

-    struct ggml_tensor * logits_mtp = res_mtp->get_logits();;
+    //struct ggml_tensor * logits_mtp = res_mtp->get_logits();
+
     //LLAMA_LOG_INFO("logits_mtp pointer address: %p\n", (void*)logits_mtp);

-    if (logits_mtp) {
-        ctx->set_logits_ith(logits_mtp, sched, last_tok_idx);
-    }
+    //if (logits_mtp) {
+    //    ctx->set_logits_ith(logits_mtp, sched, last_tok_idx);
+    //}
+    struct ggml_tensor * token_id_tensor = ggml_get_tensor(res_mtp->get_ctx(), "mtp_argmax_result");
+
+
+    llama_token token_id = 0; // The C++ variable to hold the result.
+
+    // ggml_backend_tensor_get is the function for GPU->CPU copies.
+    // We are copying a single 32-bit integer.
+    ggml_backend_tensor_get(
+        token_id_tensor,
+        &token_id,          // Pointer to our C++ variable
+        0,                  // Starting offset in bytes
+        sizeof(llama_token) // Number of bytes to copy
+    );

     ggml_backend_sched_free(sched);
+
+    return token_id;

 }
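Note: the read-back works because the graph names its result tensor; ggml_get_tensor looks it up by name in the graph's context, and ggml_backend_tensor_get copies its bytes into host memory regardless of which backend holds the data. A standalone sketch of the same pattern (graph_ctx is illustrative; a defensive version also checks the lookup for NULL, which the committed code does not):

    // generic read-one-scalar-from-a-computed-graph pattern, assuming the
    // tensor was previously tagged with ggml_set_name(t, "mtp_argmax_result")
    struct ggml_tensor * t = ggml_get_tensor(graph_ctx, "mtp_argmax_result");
    int32_t value = 0;
    if (t != NULL) {
        // copies sizeof(value) bytes from the (possibly device-resident) tensor
        ggml_backend_tensor_get(t, &value, /*offset=*/0, /*size=*/sizeof(value));
    }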

src/llama-model.cpp (4 additions, 0 deletions)

@@ -14100,6 +14100,10 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
         res->t_logits = cur;

         ggml_build_forward_expand(gf, res->t_logits);
+
+        struct ggml_tensor * token_id_tensor = ggml_argmax(ctx0, cur);
+        ggml_set_name(token_id_tensor, "mtp_argmax_result");
+        ggml_build_forward_expand(gf, token_id_tensor);
     }
 };
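Note: greedy sampling is just an argmax over the logits, so it can live inside the compute graph: ggml_argmax reduces each logits row to a single int32 index, and only that index has to cross the device boundary instead of the full vocabulary-sized row. A self-contained CPU-backend sketch of the op (not the committed code; assumes a recent ggml where ggml_graph_compute_with_ctx is declared in ggml-cpu.h):

    #include <stdio.h>
    #include <string.h>
    #include <stdint.h>
    #include "ggml.h"
    #include "ggml-cpu.h"

    int main(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16 * 1024 * 1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx0 = ggml_init(params);

        // stand-in for one row of vocabulary logits
        struct ggml_tensor * logits = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 4);
        const float vals[4] = { 0.1f, 2.5f, -1.0f, 0.3f };
        memcpy(logits->data, vals, sizeof(vals));

        struct ggml_tensor * best = ggml_argmax(ctx0, logits); // I32 index of the max
        struct ggml_cgraph * gf = ggml_new_graph(ctx0);
        ggml_build_forward_expand(gf, best);
        ggml_graph_compute_with_ctx(ctx0, gf, /*n_threads=*/1);

        printf("greedy token id: %d\n", ((const int32_t *) best->data)[0]); // prints 1
        ggml_free(ctx0);
        return 0;
    }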
