31 changes: 16 additions & 15 deletions common/speculative.cpp
@@ -11,7 +11,9 @@
 
 struct common_speculative {
     struct llama_context * ctx;
+
     struct common_sampler * smpl;
+    struct common_sampler * smpl_infill;
 
     llama_batch batch;
     llama_tokens prompt;
@@ -20,49 +22,48 @@ struct common_speculative {
 struct common_speculative * common_speculative_init(
         struct llama_context * ctx_dft) {
     auto * result = new common_speculative {
-        /* .ctx    = */ ctx_dft,
-        /* .smpl   = */ nullptr,
-        /* .batch  = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
-        /* .prompt = */ {},
+        /* .ctx         = */ ctx_dft,
+        /* .smpl        = */ nullptr,
+        /* .smpl_infill = */ nullptr,
+        /* .batch       = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
+        /* .prompt      = */ {},
     };
 
     // TODO: optimize or pass from outside?
-#if 0
     {
         common_params_sampling params;
         params.no_perf = false;
 
-        params.top_k = 40;
-        params.top_p = 0.9;
+        params.top_k = 10;
 
         params.samplers = {
             COMMON_SAMPLER_TYPE_TOP_K,
-            COMMON_SAMPLER_TYPE_TOP_P,
-            COMMON_SAMPLER_TYPE_INFILL,
         };
 
         result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
     }
-#else
+
     {
         common_params_sampling params;
         params.no_perf = false;
 
-        params.top_k = 10;
+        params.top_k = 40;
+        params.top_p = 0.9;
 
         params.samplers = {
             COMMON_SAMPLER_TYPE_TOP_K,
+            COMMON_SAMPLER_TYPE_TOP_P,
+            COMMON_SAMPLER_TYPE_INFILL,
         };
 
-        result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
+        result->smpl_infill = common_sampler_init(llama_get_model(ctx_dft), params);
     }
-#endif
 
     return result;
 }
 
 void common_speculative_free(struct common_speculative * spec) {
     common_sampler_free(spec->smpl);
+    common_sampler_free(spec->smpl_infill);
 
     llama_batch_free(spec->batch);
 
@@ -133,7 +134,7 @@ llama_tokens common_speculative_gen_draft(
         llama_token id_last) {
     auto & batch = spec->batch;
     auto & ctx = spec->ctx;
-    auto & smpl = spec->smpl;
+    auto & smpl = params.infill ? spec->smpl_infill : spec->smpl;
     auto & prompt = spec->prompt;
 
     int reuse_i = 0;
2 changes: 2 additions & 0 deletions common/speculative.h
@@ -10,6 +10,8 @@ struct common_speculative_params {
     int n_reuse = 256;
 
     float p_min = 0.9f; // min probability required to accept a token in the draft
+
+    bool infill = false; // use infill sampling (useful for FIM)
 };
 
 struct common_speculative * common_speculative_init(struct llama_context * ctx_dft);
1 change: 1 addition & 0 deletions examples/server/server.cpp
@@ -2315,6 +2315,7 @@ struct server_context {
             params_spec.n_draft = slot.params.speculative.n_max;
             params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
             params_spec.p_min = slot.params.speculative.p_min;
+            params_spec.infill = slot.inf_type == SERVER_TASK_INF_TYPE_INFILL;
 
             llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
 
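For reference, a minimal usage sketch of the public API touched by this diff, showing how a caller could opt into the new infill sampler. The setup names (ctx_dft, prompt_tokens, id_last, is_infill_request) and the n_draft value are illustrative assumptions, not part of the change:

    // Illustrative sketch: assumes ctx_dft is a draft-model llama_context and that
    // prompt_tokens, id_last and is_infill_request are provided by the caller.
    struct common_speculative * spec = common_speculative_init(ctx_dft);

    common_speculative_params params_spec;
    params_spec.n_draft = 16;                // max number of tokens to draft
    params_spec.p_min   = 0.9f;              // min probability required to accept a drafted token
    params_spec.infill  = is_infill_request; // new flag: use the infill sampler for FIM requests

    llama_tokens draft = common_speculative_gen_draft(spec, params_spec, prompt_tokens, id_last);

    // ... validate the drafted tokens against the target model ...

    common_speculative_free(spec);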