
Commit 8742ce0

feat: apply logits + greedy sampler
1 parent 5a5bce8 commit 8742ce0

File tree

common/sampling.cpp
common/sampling.h
common/speculative.cpp

3 files changed: +19 -6 lines changed

common/sampling.cpp

Lines changed: 4 additions & 0 deletions
@@ -582,3 +582,7 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
 
     return samplers;
 }
+
+void common_sampler_apply_chain(struct common_sampler * gsmpl, struct llama_token_data_array * cur_p) {
+    llama_sampler_apply(gsmpl->chain, cur_p);
+}

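Note on the new helper: it exists so that code outside common/sampling.cpp (where struct common_sampler and its chain member are defined) can run the configured sampler chain over a candidate array it has filled itself; the helper does not fetch logits and does not accept the resulting token. Below is a hedged usage sketch, assuming an initialized common_sampler * smpl whose candidate buffer already has room for n_vocab entries (as in this commit's own call site); the wrapper name apply_chain_to_logits and the logits parameter are illustrative, not part of this commit.

    // Illustrative caller: populate the sampler's candidate array from raw logits,
    // then let the new helper run the configured chain over it in place.
    static void apply_chain_to_logits(struct common_sampler * smpl, const float * logits, int n_vocab) {
        llama_token_data_array * cur_p = common_sampler_get_candidates(smpl);

        cur_p->size = n_vocab;
        for (int i = 0; i < n_vocab; ++i) {
            cur_p->data[i].id    = i;         // token id == index into the vocab
            cur_p->data[i].logit = logits[i]; // raw, unnormalized score
        }
        cur_p->sorted = false; // candidates are in vocab order, not ranked yet

        // runs the chain (top-k/top-p/temperature/..., as configured) over cur_p
        common_sampler_apply_chain(smpl, cur_p);
    }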
common/sampling.h

Lines changed: 2 additions & 0 deletions
@@ -105,3 +105,5 @@ std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std:
 
 llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,
                                        const char * grammar_kind, const char * grammar_data);
+
+void common_sampler_apply_chain(struct common_sampler * gsmpl, struct llama_token_data_array * cur_p);

common/speculative.cpp

Lines changed: 13 additions & 6 deletions
@@ -379,15 +379,22 @@ llama_token mtp_speculative_gen_draft(
 
     llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);
 
-    llama_token id = common_sampler_sample(smpl, ctx, last_tok_idx, true);
+    const llama_model * model = llama_get_model(ctx);
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+    const int n_vocab = llama_n_vocab(vocab);
 
-    const auto * cur_p = common_sampler_get_candidates(smpl);
-    for (int k = 0; k < std::min(3, (int)cur_p->size); ++k) {
-        LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
-            k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
+    llama_token_data_array * cur_p = common_sampler_get_candidates(smpl);
+
+    cur_p->size = n_vocab;
+    for (int i = 0; i < n_vocab; ++i) {
+        cur_p->data[i].id = i;
+        cur_p->data[i].logit = llama_get_logits_ith(ctx, last_tok_idx)[i];
     }
+    cur_p->sorted = false;
+
+    common_sampler_apply_chain(smpl, cur_p);
 
-    common_sampler_accept(smpl, id, true);
+    const llama_token id = cur_p->data[0].id;
 
     llama_batch_free(batch);

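Putting the speculative.cpp hunk together: the draft token is no longer produced via common_sampler_sample and common_sampler_accept; instead the raw logits at last_tok_idx are copied into the candidate array, the chain is applied, and the first candidate is read back as the pick. A condensed sketch of that flow follows, reusing the apply_chain_to_logits sketch above; the function name greedy_draft_pick is illustrative. One caveat worth hedging: cur_p->data[0].id equals the greedy argmax only if the chain leaves its best candidate at the front (e.g. because a sorting sampler such as top-k ran); a chain that merely sets cur_p->selected would require reading that index instead. The diff itself does not state which chain configuration is assumed.

    // Condensed draft-token flow; ctx, smpl and last_tok_idx correspond to the
    // variables in mtp_speculative_gen_draft.
    static llama_token greedy_draft_pick(struct llama_context * ctx, struct common_sampler * smpl, int32_t last_tok_idx) {
        const llama_model * model = llama_get_model(ctx);
        const llama_vocab * vocab = llama_model_get_vocab(model);
        const int n_vocab         = llama_n_vocab(vocab);

        // raw logits for row last_tok_idx of the most recent decode
        const float * logits = llama_get_logits_ith(ctx, last_tok_idx);

        // fill the candidate array and run the configured sampler chain (see sketch above)
        apply_chain_to_logits(smpl, logits, n_vocab);

        // the commit reads the head of the candidate array as the draft token;
        // this is the argmax only when the chain has sorted the candidates
        llama_token_data_array * cur_p = common_sampler_get_candidates(smpl);
        return cur_p->data[0].id;
    }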