
Commit 26d643c

Optimize speculative decoding performance of llama-server
1 parent c959f46 commit 26d643c
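
The diff below moves the draft/verify step out of the separate pass that previously ran after the main decode and into the per-slot loop that prepares the batch, and it adds a `speculative_accepted` flag so the "no tokens to decode" warning is skipped when the batch is empty only because a slot already advanced through its own speculative batch. For readers unfamiliar with the technique, the following is a minimal, self-contained sketch of the draft-then-verify acceptance rule that this path relies on; `verify_draft` and the `sample_target` callback are illustrative stand-ins, not llama.cpp API.

```cpp
#include <cstdio>
#include <functional>
#include <vector>

using token = int;

// Hypothetical stand-in for the target model: given the accepted prefix,
// return the token the target sampler would choose at the next position.
using sampler_fn = std::function<token(const std::vector<token> & prefix)>;

// Accept the longest prefix of `draft` that matches what the target model
// samples, then append one freshly sampled target token. This mirrors the
// bookkeeping in the patch: every returned token advances the sequence,
// but only ids.size() - 1 of them count as accepted draft tokens.
static std::vector<token> verify_draft(const std::vector<token> & prefix,
                                       const std::vector<token> & draft,
                                       const sampler_fn & sample_target) {
    std::vector<token> accepted = prefix;
    std::vector<token> ids;

    for (token d : draft) {
        const token t = sample_target(accepted);
        ids.push_back(t);
        if (t != d) {
            return ids;        // mismatch: keep the target token, discard the rest of the draft
        }
        accepted.push_back(t); // match: the drafted token is accepted
    }

    ids.push_back(sample_target(accepted)); // whole draft matched: one extra token on top
    return ids;
}

int main() {
    // Toy "target model" that always continues with the prefix length (illustrative only).
    const sampler_fn sample_target = [](const std::vector<token> & prefix) {
        return (token) prefix.size();
    };

    const std::vector<token> prefix = {0, 1, 2};
    const std::vector<token> draft  = {3, 4, 9}; // the third drafted token will be rejected

    const std::vector<token> ids = verify_draft(prefix, draft, sample_target);

    std::printf("emitted %zu tokens, accepted %zu of %zu drafted\n",
                ids.size(), ids.size() - 1, draft.size());
    return 0;
}
```

The payoff is that all accepted tokens come from a single decode of the speculative batch on the target model, so a high acceptance rate yields several tokens per target-model pass instead of one.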

File tree

1 file changed: +96 −103 lines changed

tools/server/server.cpp

Lines changed: 96 additions & 103 deletions
@@ -3038,6 +3038,7 @@ struct server_context {
 
         // track if given slot can be batched with slots already in the batch
         server_slot * slot_batched = nullptr;
+        bool speculative_accepted = false;
 
         auto accept_special_token = [&](server_slot & slot, llama_token token) {
             return params_base.special || slot.params.sampling.preserved_tokens.find(token) != slot.params.sampling.preserved_tokens.end();
@@ -3049,6 +3050,97 @@ struct server_context {
                 continue;
             }
 
+            if (slot.state == SLOT_STATE_GENERATING && slot.is_processing() && slot.can_speculate()) {
+                if (mctx) {
+                    // we should never reach this, as speculative is automatically disabled if mmproj is loaded
+                    GGML_ABORT("not supported by multimodal");
+                }
+
+                // determine the max draft that fits the current slot state
+                int n_draft_max = slot.params.speculative.n_max;
+
+                // note: n_past is not yet increased for the `id` token sampled above
+                // also, need to leave space for 1 extra token to allow context shifts
+                n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
+
+                if (slot.n_remaining > 0) {
+                    n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
+                }
+
+                SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
+
+                if (n_draft_max >= slot.params.speculative.n_min) {
+                    llama_token id = slot.sampled;
+
+                    struct common_speculative_params params_spec;
+                    params_spec.n_draft = n_draft_max;
+                    params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
+                    params_spec.p_min   = slot.params.speculative.p_min;
+
+                    const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens();
+                    llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
+
+                    // ignore small drafts
+                    if (slot.params.speculative.n_min <= (int) draft.size()) {
+                        // keep track of total number of drafted tokens tested
+                        slot.n_draft_total += draft.size();
+
+                        // construct the speculation batch
+                        common_batch_clear(slot.batch_spec);
+                        common_batch_add  (slot.batch_spec, id, slot.n_past, { slot.id }, true);
+
+                        for (size_t i = 0; i < draft.size(); ++i) {
+                            common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
+                        }
+
+                        SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
+
+                        llama_decode(ctx, slot.batch_spec);
+
+                        // the accepted tokens from the speculation
+                        const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
+
+                        slot.n_past    += ids.size();
+                        slot.n_decoded += ids.size();
+
+                        // update how many tokens out of those tested were accepted
+                        slot.n_draft_accepted += ids.size() - 1;
+
+                        slot.cache_tokens.push_back(id);
+                        slot.cache_tokens.insert({ids.begin(), ids.end() - 1});
+
+                        llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1);
+
+                        for (size_t i = 0; i < ids.size(); ++i) {
+                            completion_token_output result;
+
+                            result.tok          = ids[i];
+                            result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
+                            result.prob         = 1.0f; // set later
+
+                            // TODO: set result.probs
+
+                            if (!process_token(result, slot)) {
+                                // release slot because of stop condition
+                                slot.release();
+                                slot.print_timings();
+                                send_final_response(slot);
+                                metrics.on_prediction(slot);
+                                break;
+                            }
+                        }
+
+                        speculative_accepted = true;
+                        SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
+                        continue;
+                    }
+
+                    SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
+                } else {
+                    SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min);
+                }
+            }
+
             // check if we can batch this slot with the previous one
             if (!slot_batched) {
                 slot_batched = &slot;
@@ -3372,7 +3464,10 @@ struct server_context {
         }
 
         if (batch.n_tokens == 0) {
-            SRV_WRN("%s", "no tokens to decode\n");
+            if (!speculative_accepted) {
+                SRV_WRN("%s", "no tokens to decode\n");
+            }
+
             return;
         }
 
@@ -3510,108 +3605,6 @@ struct server_context {
                     continue;
                 }
             }
-
-            // do speculative decoding
-            for (auto & slot : slots) {
-                if (!slot.is_processing() || !slot.can_speculate()) {
-                    continue;
-                }
-
-                if (slot.state != SLOT_STATE_GENERATING) {
-                    continue;
-                }
-
-                if (mctx) {
-                    // we should never reach this, as speculative is automatically disabled if mmproj is loaded
-                    GGML_ABORT("not supported by multimodal");
-                }
-
-                // determine the max draft that fits the current slot state
-                int n_draft_max = slot.params.speculative.n_max;
-
-                // note: n_past is not yet increased for the `id` token sampled above
-                // also, need to leave space for 1 extra token to allow context shifts
-                n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
-
-                if (slot.n_remaining > 0) {
-                    n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
-                }
-
-                SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
-
-                if (n_draft_max < slot.params.speculative.n_min) {
-                    SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min);
-
-                    continue;
-                }
-
-                llama_token id = slot.sampled;
-
-                struct common_speculative_params params_spec;
-                params_spec.n_draft = n_draft_max;
-                params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
-                params_spec.p_min   = slot.params.speculative.p_min;
-
-                const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens();
-                llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
-
-                // ignore small drafts
-                if (slot.params.speculative.n_min > (int) draft.size()) {
-                    SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
-
-                    continue;
-                }
-
-                // keep track of total number of drafted tokens tested
-                slot.n_draft_total += draft.size();
-
-                // construct the speculation batch
-                common_batch_clear(slot.batch_spec);
-                common_batch_add  (slot.batch_spec, id, slot.n_past, { slot.id }, true);
-
-                for (size_t i = 0; i < draft.size(); ++i) {
-                    common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
-                }
-
-                SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
-
-                llama_decode(ctx, slot.batch_spec);
-
-                // the accepted tokens from the speculation
-                const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
-
-                slot.n_past    += ids.size();
-                slot.n_decoded += ids.size();
-
-                // update how many tokens out of those tested were accepted
-                slot.n_draft_accepted += ids.size() - 1;
-
-                slot.cache_tokens.push_back(id);
-                slot.cache_tokens.insert({ids.begin(), ids.end() - 1});
-
-                llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1);
-
-                for (size_t i = 0; i < ids.size(); ++i) {
-                    completion_token_output result;
-
-                    result.tok          = ids[i];
-                    result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
-                    result.prob         = 1.0f; // set later
-
-                    // TODO: set result.probs
-
-                    if (!process_token(result, slot)) {
-                        // release slot because of stop condition
-                        slot.release();
-                        slot.print_timings();
-                        send_final_response(slot);
-                        metrics.on_prediction(slot);
-                        break;
-                    }
-                }
-
-                SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
-            }
         }
 
         SRV_DBG("%s", "run slots completed\n");
