Commit b83cae0

speculative : add infill mode

ggml-ci

Parent: 0eb4e12

3 files changed, 19 additions and 15 deletions

common/speculative.cpp (16 additions, 15 deletions)

```diff
@@ -11,7 +11,9 @@
 
 struct common_speculative {
     struct llama_context * ctx;
+
     struct common_sampler * smpl;
+    struct common_sampler * smpl_infill;
 
     llama_batch batch;
     llama_tokens prompt;
@@ -20,49 +22,48 @@ struct common_speculative {
 struct common_speculative * common_speculative_init(
         struct llama_context * ctx_dft) {
     auto * result = new common_speculative {
-        /* .ctx    = */ ctx_dft,
-        /* .smpl   = */ nullptr,
-        /* .batch  = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
-        /* .prompt = */ {},
+        /* .ctx         = */ ctx_dft,
+        /* .smpl        = */ nullptr,
+        /* .smpl_infill = */ nullptr,
+        /* .batch       = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
+        /* .prompt      = */ {},
     };
 
-    // TODO: optimize or pass from outside?
-#if 0
     {
         common_params_sampling params;
         params.no_perf = false;
 
-        params.top_k = 40;
-        params.top_p = 0.9;
+        params.top_k = 10;
 
         params.samplers = {
             COMMON_SAMPLER_TYPE_TOP_K,
-            COMMON_SAMPLER_TYPE_TOP_P,
-            COMMON_SAMPLER_TYPE_INFILL,
         };
 
         result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
     }
-#else
+
     {
         common_params_sampling params;
         params.no_perf = false;
 
-        params.top_k = 10;
+        params.top_k = 40;
+        params.top_p = 0.9;
 
         params.samplers = {
             COMMON_SAMPLER_TYPE_TOP_K,
+            COMMON_SAMPLER_TYPE_TOP_P,
+            COMMON_SAMPLER_TYPE_INFILL,
         };
 
-        result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
+        result->smpl_infill = common_sampler_init(llama_get_model(ctx_dft), params);
     }
-#endif
 
     return result;
 }
 
 void common_speculative_free(struct common_speculative * spec) {
     common_sampler_free(spec->smpl);
+    common_sampler_free(spec->smpl_infill);
 
     llama_batch_free(spec->batch);
 
@@ -133,7 +134,7 @@ llama_tokens common_speculative_gen_draft(
         llama_token id_last) {
     auto & batch  = spec->batch;
     auto & ctx    = spec->ctx;
-    auto & smpl   = spec->smpl;
+    auto & smpl   = params.infill ? spec->smpl_infill : spec->smpl;
     auto & prompt = spec->prompt;
 
     int reuse_i = 0;
```
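Note: the only run-time change in `common_speculative_gen_draft` is the sampler selection on `params.infill`; the drafting loop itself is untouched. A minimal stand-alone sketch of that dispatch, using simplified stand-in types rather than llama.cpp's real structs (`sampler_stub` and `speculative_stub` are illustrative, not from the commit):

```cpp
// Stand-alone sketch of the dispatch added in common_speculative_gen_draft.
// The types below are illustrative stand-ins, not llama.cpp types.
#include <cstdio>
#include <initializer_list>

struct sampler_stub { const char * desc; };

struct speculative_stub {
    sampler_stub smpl        = { "top-k(10)" };                       // regular drafting
    sampler_stub smpl_infill = { "top-k(40) + top-p(0.9) + infill" }; // FIM drafting
};

int main() {
    speculative_stub spec;

    for (bool infill : { false, true }) {
        // same selection as the hunk above: one flag picks the preconfigured sampler
        sampler_stub & smpl = infill ? spec.smpl_infill : spec.smpl;
        std::printf("infill=%d -> draft sampler: %s\n", infill, smpl.desc);
    }

    return 0;
}
```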

common/speculative.h (2 additions, 0 deletions)

```diff
@@ -10,6 +10,8 @@ struct common_speculative_params {
     int n_reuse = 256;
 
     float p_min = 0.9f; // min probability required to accept a token in the draft
+
+    bool infill = false; // use infill sampling (useful for FIM)
 };
 
 struct common_speculative * common_speculative_init(struct llama_context * ctx_dft);
```
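For reference, a hedged caller-side sketch of the new flag, combining the declarations above with the call shape used in server.cpp below. Here `ctx_dft`, `prompt_tokens`, and `id_last` are assumed to be prepared elsewhere, and the `n_draft` value is illustrative:

```cpp
// Hypothetical usage: enable infill drafting for a FIM request.
// Assumes ctx_dft (draft llama_context *), prompt_tokens (llama_tokens),
// and id_last (llama_token) are set up by the caller.
struct common_speculative * spec = common_speculative_init(ctx_dft);

common_speculative_params params_spec;
params_spec.n_draft = 16;   // illustrative value
params_spec.p_min   = 0.9f; // default from the header above
params_spec.infill  = true; // draft with smpl_infill instead of smpl

llama_tokens draft = common_speculative_gen_draft(spec, params_spec, prompt_tokens, id_last);

common_speculative_free(spec);
```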

examples/server/server.cpp (1 addition, 0 deletions)

```diff
@@ -2315,6 +2315,7 @@ struct server_context {
                 params_spec.n_draft = slot.params.speculative.n_max;
                 params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
                 params_spec.p_min   = slot.params.speculative.p_min;
+                params_spec.infill  = slot.inf_type == SERVER_TASK_INF_TYPE_INFILL;
 
                 llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
 
```
