
Commit 5dac51a

Optimize speculative decoding performance of llama-server
1 parent 225e7a1 commit 5dac51a
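
This commit moves speculative decoding out of the separate "do speculative decoding" pass that ran after the main batch decode and into the per-slot loop that collects sampled tokens: a slot whose draft is accepted skips the main batch (`continue`), and the new `speculative_accepted` flag suppresses the "no tokens to decode" warning when the batch is empty only because speculation already handled the work. The standalone sketch below illustrates that reordering only; `Slot`, `speculate()` and `decode_batch()` are hypothetical stand-ins, not the actual llama-server API.

// Minimal standalone sketch of the control-flow change in this commit.
// The types and helpers (Slot, speculate, decode_batch) are hypothetical stand-ins.
#include <cstdio>
#include <vector>

struct Slot {
    bool generating    = true; // stand-in for SLOT_STATE_GENERATING
    bool can_speculate = true; // stand-in for slot.can_speculate()
};

// stand-in for drafting + verifying with the draft model; returns true if a draft was accepted
static bool speculate(Slot &) { return true; }

// stand-in for decoding the main batch
static void decode_batch(int n_tokens) { std::printf("decode %d token(s)\n", n_tokens); }

// before this commit: batch the sampled tokens, decode, then run a separate speculative pass
static void update_slots_before(std::vector<Slot> & slots) {
    const int n_tokens = (int) slots.size(); // every generating slot contributes its sampled token
    decode_batch(n_tokens);

    for (auto & slot : slots) {              // extra pass after the main decode
        if (slot.generating && slot.can_speculate) {
            speculate(slot);
        }
    }
}

// after this commit: speculate while visiting each slot; an accepted draft makes the slot
// skip the main batch, and the empty-batch warning is suppressed in that case
static void update_slots_after(std::vector<Slot> & slots) {
    int  n_tokens             = 0;
    bool speculative_accepted = false;

    for (auto & slot : slots) {
        if (slot.generating && slot.can_speculate && speculate(slot)) {
            speculative_accepted = true;
            continue; // the speculative decode already advanced this slot
        }
        n_tokens++;
    }

    if (n_tokens == 0) {
        if (!speculative_accepted) {
            std::printf("no tokens to decode\n");
        }
        return;
    }

    decode_batch(n_tokens);
}

int main() {
    std::vector<Slot> slots(2);
    update_slots_before(slots);
    update_slots_after(slots);
    return 0;
}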

File tree

1 file changed (+96, -103)


tools/server/server.cpp

Lines changed: 96 additions & 103 deletions
@@ -3040,6 +3040,7 @@ struct server_context {
 
         // track if given slot can be batched with slots already in the batch
         server_slot * slot_batched = nullptr;
+        bool speculative_accepted = false;
 
         auto accept_special_token = [&](server_slot & slot, llama_token token) {
             return params_base.special || slot.params.sampling.preserved_tokens.find(token) != slot.params.sampling.preserved_tokens.end();
@@ -3051,6 +3052,97 @@ struct server_context {
                 continue;
             }
 
+            if (slot.state == SLOT_STATE_GENERATING && slot.is_processing() && slot.can_speculate()) {
+                if (mctx) {
+                    // we should never reach this, as speculative is automatically disabled if mmproj is loaded
+                    GGML_ABORT("not supported by multimodal");
+                }
+
+                // determine the max draft that fits the current slot state
+                int n_draft_max = slot.params.speculative.n_max;
+
+                // note: n_past is not yet increased for the `id` token sampled above
+                // also, need to leave space for 1 extra token to allow context shifts
+                n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
+
+                if (slot.n_remaining > 0) {
+                    n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
+                }
+
+                SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
+
+                if (n_draft_max >= slot.params.speculative.n_min) {
+                    llama_token id = slot.sampled;
+
+                    struct common_speculative_params params_spec;
+                    params_spec.n_draft = n_draft_max;
+                    params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
+                    params_spec.p_min = slot.params.speculative.p_min;
+
+                    const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens();
+                    llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
+
+                    // ignore small drafts
+                    if (slot.params.speculative.n_min <= (int) draft.size()) {
+                        // keep track of total number of drafted tokens tested
+                        slot.n_draft_total += draft.size();
+
+                        // construct the speculation batch
+                        common_batch_clear(slot.batch_spec);
+                        common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true);
+
+                        for (size_t i = 0; i < draft.size(); ++i) {
+                            common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
+                        }
+
+                        SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
+
+                        llama_decode(ctx, slot.batch_spec);
+
+                        // the accepted tokens from the speculation
+                        const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
+
+                        slot.n_past += ids.size();
+                        slot.n_decoded += ids.size();
+
+                        // update how many tokens out of those tested were accepted
+                        slot.n_draft_accepted += ids.size() - 1;
+
+                        slot.cache_tokens.push_back(id);
+                        slot.cache_tokens.insert({ids.begin(), ids.end() - 1});
+
+                        llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1);
+
+                        for (size_t i = 0; i < ids.size(); ++i) {
+                            completion_token_output result;
+
+                            result.tok = ids[i];
+                            result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
+                            result.prob = 1.0f; // set later
+
+                            // TODO: set result.probs
+
+                            if (!process_token(result, slot)) {
+                                // release slot because of stop condition
+                                slot.release();
+                                slot.print_timings();
+                                send_final_response(slot);
+                                metrics.on_prediction(slot);
+                                break;
+                            }
+                        }
+
+                        speculative_accepted = true;
+                        SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
+                        continue;
+                    }
+
+                    SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
+                } else {
+                    SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min);
+                }
+            }
+
             // check if we can batch this slot with the previous one
             if (!slot_batched) {
                 slot_batched = &slot;
@@ -3374,7 +3466,10 @@ struct server_context {
         }
 
         if (batch.n_tokens == 0) {
-            SRV_WRN("%s", "no tokens to decode\n");
+            if (!speculative_accepted) {
+                SRV_WRN("%s", "no tokens to decode\n");
+            }
+
             return;
         }
 
@@ -3515,108 +3610,6 @@ struct server_context {
                     continue;
                 }
             }
-
-            // do speculative decoding
-            for (auto & slot : slots) {
-                if (!slot.is_processing() || !slot.can_speculate()) {
-                    continue;
-                }
-
-                if (slot.state != SLOT_STATE_GENERATING) {
-                    continue;
-                }
-
-                if (mctx) {
-                    // we should never reach this, as speculative is automatically disabled if mmproj is loaded
-                    GGML_ABORT("not supported by multimodal");
-                }
-
-                // determine the max draft that fits the current slot state
-                int n_draft_max = slot.params.speculative.n_max;
-
-                // note: n_past is not yet increased for the `id` token sampled above
-                // also, need to leave space for 1 extra token to allow context shifts
-                n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
-
-                if (slot.n_remaining > 0) {
-                    n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
-                }
-
-                SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
-
-                if (n_draft_max < slot.params.speculative.n_min) {
-                    SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min);
-
-                    continue;
-                }
-
-                llama_token id = slot.sampled;
-
-                struct common_speculative_params params_spec;
-                params_spec.n_draft = n_draft_max;
-                params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
-                params_spec.p_min = slot.params.speculative.p_min;
-
-                const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens();
-                llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
-
-                // ignore small drafts
-                if (slot.params.speculative.n_min > (int) draft.size()) {
-                    SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
-
-                    continue;
-                }
-
-                // keep track of total number of drafted tokens tested
-                slot.n_draft_total += draft.size();
-
-                // construct the speculation batch
-                common_batch_clear(slot.batch_spec);
-                common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true);
-
-                for (size_t i = 0; i < draft.size(); ++i) {
-                    common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
-                }
-
-                SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
-
-                llama_decode(ctx, slot.batch_spec);
-
-                // the accepted tokens from the speculation
-                const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
-
-                slot.n_past += ids.size();
-                slot.n_decoded += ids.size();
-
-                // update how many tokens out of those tested were accepted
-                slot.n_draft_accepted += ids.size() - 1;
-
-                slot.cache_tokens.push_back(id);
-                slot.cache_tokens.insert({ids.begin(), ids.end() - 1});
-
-                llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1);
-
-                for (size_t i = 0; i < ids.size(); ++i) {
-                    completion_token_output result;
-
-                    result.tok = ids[i];
-                    result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
-                    result.prob = 1.0f; // set later
-
-                    // TODO: set result.probs
-
-                    if (!process_token(result, slot)) {
-                        // release slot because of stop condition
-                        slot.release();
-                        slot.print_timings();
-                        send_final_response(slot);
-                        metrics.on_prediction(slot);
-                        break;
-                    }
-                }
-
-                SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
-            }
         }
 
         SRV_DBG("%s", "run slots completed\n");
