
Commit 7c84cc9

Optimize speculative decoding performance of llama-server
1 parent f5e96b3 commit 7c84cc9

File tree

1 file changed: +96 −103 lines changed


tools/server/server.cpp

Lines changed: 96 additions & 103 deletions
@@ -3039,6 +3039,7 @@ struct server_context {
 
         // track if given slot can be batched with slots already in the batch
         server_slot * slot_batched = nullptr;
+        bool speculative_accepted  = false;
 
         auto accept_special_token = [&](server_slot & slot, llama_token token) {
             return params_base.special || slot.params.sampling.preserved_tokens.find(token) != slot.params.sampling.preserved_tokens.end();
@@ -3050,6 +3051,97 @@ struct server_context {
                 continue;
             }
 
+            if (slot.state == SLOT_STATE_GENERATING && slot.is_processing() && slot.can_speculate()) {
+                if (mctx) {
+                    // we should never reach this, as speculative is automatically disabled if mmproj is loaded
+                    GGML_ABORT("not supported by multimodal");
+                }
+
+                // determine the max draft that fits the current slot state
+                int n_draft_max = slot.params.speculative.n_max;
+
+                // note: n_past is not yet increased for the `id` token sampled above
+                //       also, need to leave space for 1 extra token to allow context shifts
+                n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
+
+                if (slot.n_remaining > 0) {
+                    n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
+                }
+
+                SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
+
+                if (n_draft_max >= slot.params.speculative.n_min) {
+                    llama_token id = slot.sampled;
+
+                    struct common_speculative_params params_spec;
+                    params_spec.n_draft = n_draft_max;
+                    params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
+                    params_spec.p_min   = slot.params.speculative.p_min;
+
+                    const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens();
+                    llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
+
+                    // ignore small drafts
+                    if (slot.params.speculative.n_min <= (int) draft.size()) {
+                        // keep track of total number of drafted tokens tested
+                        slot.n_draft_total += draft.size();
+
+                        // construct the speculation batch
+                        common_batch_clear(slot.batch_spec);
+                        common_batch_add  (slot.batch_spec, id, slot.n_past, { slot.id }, true);
+
+                        for (size_t i = 0; i < draft.size(); ++i) {
+                            common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
+                        }
+
+                        SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
+
+                        llama_decode(ctx, slot.batch_spec);
+
+                        // the accepted tokens from the speculation
+                        const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
+
+                        slot.n_past    += ids.size();
+                        slot.n_decoded += ids.size();
+
+                        // update how many tokens out of those tested were accepted
+                        slot.n_draft_accepted += ids.size() - 1;
+
+                        slot.cache_tokens.push_back(id);
+                        slot.cache_tokens.insert({ids.begin(), ids.end() - 1});
+
+                        llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1);
+
+                        for (size_t i = 0; i < ids.size(); ++i) {
+                            completion_token_output result;
+
+                            result.tok          = ids[i];
+                            result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
+                            result.prob         = 1.0f; // set later
+
+                            // TODO: set result.probs
+
+                            if (!process_token(result, slot)) {
+                                // release slot because of stop condition
+                                slot.release();
+                                slot.print_timings();
+                                send_final_response(slot);
+                                metrics.on_prediction(slot);
+                                break;
+                            }
+                        }
+
+                        speculative_accepted = true;
+                        SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
+                        continue;
+                    }
+
+                    SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
+                } else {
+                    SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min);
+                }
+            }
+
             // check if we can batch this slot with the previous one
             if (!slot_batched) {
                 slot_batched = &slot;
@@ -3373,7 +3465,10 @@ struct server_context {
         }
 
         if (batch.n_tokens == 0) {
-            SRV_WRN("%s", "no tokens to decode\n");
+            if (!speculative_accepted) {
+                SRV_WRN("%s", "no tokens to decode\n");
+            }
+
             return;
         }
 
@@ -3514,108 +3609,6 @@ struct server_context {
                     continue;
                 }
             }
-
-            // do speculative decoding
-            for (auto & slot : slots) {
-                if (!slot.is_processing() || !slot.can_speculate()) {
-                    continue;
-                }
-
-                if (slot.state != SLOT_STATE_GENERATING) {
-                    continue;
-                }
-
-                if (mctx) {
-                    // we should never reach this, as speculative is automatically disabled if mmproj is loaded
-                    GGML_ABORT("not supported by multimodal");
-                }
-
-                // determine the max draft that fits the current slot state
-                int n_draft_max = slot.params.speculative.n_max;
-
-                // note: n_past is not yet increased for the `id` token sampled above
-                //       also, need to leave space for 1 extra token to allow context shifts
-                n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
-
-                if (slot.n_remaining > 0) {
-                    n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
-                }
-
-                SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
-
-                if (n_draft_max < slot.params.speculative.n_min) {
-                    SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min);
-
-                    continue;
-                }
-
-                llama_token id = slot.sampled;
-
-                struct common_speculative_params params_spec;
-                params_spec.n_draft = n_draft_max;
-                params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
-                params_spec.p_min   = slot.params.speculative.p_min;
-
-                const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens();
-                llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
-
-                // ignore small drafts
-                if (slot.params.speculative.n_min > (int) draft.size()) {
-                    SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
-
-                    continue;
-                }
-
-                // keep track of total number of drafted tokens tested
-                slot.n_draft_total += draft.size();
-
-                // construct the speculation batch
-                common_batch_clear(slot.batch_spec);
-                common_batch_add  (slot.batch_spec, id, slot.n_past, { slot.id }, true);
-
-                for (size_t i = 0; i < draft.size(); ++i) {
-                    common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
-                }
-
-                SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
-
-                llama_decode(ctx, slot.batch_spec);
-
-                // the accepted tokens from the speculation
-                const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
-
-                slot.n_past    += ids.size();
-                slot.n_decoded += ids.size();
-
-                // update how many tokens out of those tested were accepted
-                slot.n_draft_accepted += ids.size() - 1;
-
-                slot.cache_tokens.push_back(id);
-                slot.cache_tokens.insert({ids.begin(), ids.end() - 1});
-
-                llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1);
-
-                for (size_t i = 0; i < ids.size(); ++i) {
-                    completion_token_output result;
-
-                    result.tok          = ids[i];
-                    result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
-                    result.prob         = 1.0f; // set later
-
-                    // TODO: set result.probs
-
-                    if (!process_token(result, slot)) {
-                        // release slot because of stop condition
-                        slot.release();
-                        slot.print_timings();
-                        send_final_response(slot);
-                        metrics.on_prediction(slot);
-                        break;
-                    }
-                }
-
-                SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
-            }
         }
 
         SRV_DBG("%s", "run slots completed\n");
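
For reference, the gating arithmetic that decides whether a slot attempts speculation can be exercised on its own. The snippet below is a minimal, self-contained sketch: the variable names mirror the slot fields used by this commit (n_ctx, n_past, n_remaining, speculative.n_min/n_max), but the concrete values are made up for illustration and the snippet is not part of the patch.

#include <algorithm>
#include <cstdio>

int main() {
    // hypothetical slot state, for illustration only
    const int n_max       = 16;   // slot.params.speculative.n_max
    const int n_min       = 4;    // slot.params.speculative.n_min
    const int n_ctx       = 4096; // slot.n_ctx
    const int n_past      = 4090; // slot.n_past
    const int n_remaining = 8;    // slot.n_remaining (<= 0 means no limit)

    // same clamp as the patch: leave room for the token sampled this step
    // plus one extra token so a context shift remains possible
    int n_draft_max = std::min(n_max, n_ctx - n_past - 2);
    if (n_remaining > 0) {
        n_draft_max = std::min(n_draft_max, n_remaining - 1);
    }

    if (n_draft_max < n_min) {
        std::printf("max possible draft too small: %d < %d - skip speculative decoding\n", n_draft_max, n_min);
    } else {
        std::printf("attempt a draft of up to %d tokens\n", n_draft_max);
    }
    return 0;
}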
