
Commit ab775c2

Optimize speculative decoding performance of llama-server
1 parent adef817 commit ab775c2

File tree

1 file changed

+96
-103
lines changed

1 file changed

+96
-103
lines changed

tools/server/server.cpp

Lines changed: 96 additions & 103 deletions
@@ -3045,6 +3045,7 @@ struct server_context {
 
         // track if given slot can be batched with slots already in the batch
         server_slot * slot_batched = nullptr;
+        bool speculative_accepted = false;
 
         auto accept_special_token = [&](server_slot & slot, llama_token token) {
             return params_base.special || slot.params.sampling.preserved_tokens.find(token) != slot.params.sampling.preserved_tokens.end();
@@ -3056,6 +3057,97 @@ struct server_context {
                 continue;
             }
 
+            if (slot.state == SLOT_STATE_GENERATING && slot.is_processing() && slot.can_speculate()) {
+                if (mctx) {
+                    // we should never reach this, as speculative is automatically disabled if mmproj is loaded
+                    GGML_ABORT("not supported by multimodal");
+                }
+
+                // determine the max draft that fits the current slot state
+                int n_draft_max = slot.params.speculative.n_max;
+
+                // note: n_past is not yet increased for the `id` token sampled above
+                // also, need to leave space for 1 extra token to allow context shifts
+                n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
+
+                if (slot.n_remaining > 0) {
+                    n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
+                }
+
+                SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
+
+                if (n_draft_max >= slot.params.speculative.n_min) {
+                    llama_token id = slot.sampled;
+
+                    struct common_speculative_params params_spec;
+                    params_spec.n_draft = n_draft_max;
+                    params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
+                    params_spec.p_min = slot.params.speculative.p_min;
+
+                    const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens();
+                    llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
+
+                    // ignore small drafts
+                    if (slot.params.speculative.n_min <= (int) draft.size()) {
+                        // keep track of total number of drafted tokens tested
+                        slot.n_draft_total += draft.size();
+
+                        // construct the speculation batch
+                        common_batch_clear(slot.batch_spec);
+                        common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true);
+
+                        for (size_t i = 0; i < draft.size(); ++i) {
+                            common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
+                        }
+
+                        SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
+
+                        llama_decode(ctx, slot.batch_spec);
+
+                        // the accepted tokens from the speculation
+                        const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
+
+                        slot.n_past += ids.size();
+                        slot.n_decoded += ids.size();
+
+                        // update how many tokens out of those tested were accepted
+                        slot.n_draft_accepted += ids.size() - 1;
+
+                        slot.cache_tokens.push_back(id);
+                        slot.cache_tokens.insert({ids.begin(), ids.end() - 1});
+
+                        llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1);
+
+                        for (size_t i = 0; i < ids.size(); ++i) {
+                            completion_token_output result;
+
+                            result.tok = ids[i];
+                            result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
+                            result.prob = 1.0f; // set later
+
+                            // TODO: set result.probs
+
+                            if (!process_token(result, slot)) {
+                                // release slot because of stop condition
+                                slot.release();
+                                slot.print_timings();
+                                send_final_response(slot);
+                                metrics.on_prediction(slot);
+                                break;
+                            }
+                        }
+
+                        speculative_accepted = true;
+                        SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
+                        continue;
+                    }
+
+                    SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
+                } else {
+                    SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min);
+                }
+            }
+
             // check if we can batch this slot with the previous one
             if (!slot_batched) {
                 slot_batched = &slot;
@@ -3379,7 +3471,10 @@ struct server_context {
         }
 
         if (batch.n_tokens == 0) {
-            SRV_WRN("%s", "no tokens to decode\n");
+            if (!speculative_accepted) {
+                SRV_WRN("%s", "no tokens to decode\n");
+            }
+
             return;
         }
 
@@ -3520,108 +3615,6 @@ struct server_context {
                     continue;
                 }
             }
-
-            // do speculative decoding
-            for (auto & slot : slots) {
-                if (!slot.is_processing() || !slot.can_speculate()) {
-                    continue;
-                }
-
-                if (slot.state != SLOT_STATE_GENERATING) {
-                    continue;
-                }
-
-                if (mctx) {
-                    // we should never reach this, as speculative is automatically disabled if mmproj is loaded
-                    GGML_ABORT("not supported by multimodal");
-                }
-
-                // determine the max draft that fits the current slot state
-                int n_draft_max = slot.params.speculative.n_max;
-
-                // note: n_past is not yet increased for the `id` token sampled above
-                // also, need to leave space for 1 extra token to allow context shifts
-                n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
-
-                if (slot.n_remaining > 0) {
-                    n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
-                }
-
-                SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
-
-                if (n_draft_max < slot.params.speculative.n_min) {
-                    SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min);
-
-                    continue;
-                }
-
-                llama_token id = slot.sampled;
-
-                struct common_speculative_params params_spec;
-                params_spec.n_draft = n_draft_max;
-                params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
-                params_spec.p_min = slot.params.speculative.p_min;
-
-                const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens();
-                llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
-
-                // ignore small drafts
-                if (slot.params.speculative.n_min > (int) draft.size()) {
-                    SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
-
-                    continue;
-                }
-
-                // keep track of total number of drafted tokens tested
-                slot.n_draft_total += draft.size();
-
-                // construct the speculation batch
-                common_batch_clear(slot.batch_spec);
-                common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true);
-
-                for (size_t i = 0; i < draft.size(); ++i) {
-                    common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
-                }
-
-                SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
-
-                llama_decode(ctx, slot.batch_spec);
-
-                // the accepted tokens from the speculation
-                const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
-
-                slot.n_past += ids.size();
-                slot.n_decoded += ids.size();
-
-                // update how many tokens out of those tested were accepted
-                slot.n_draft_accepted += ids.size() - 1;
-
-                slot.cache_tokens.push_back(id);
-                slot.cache_tokens.insert({ids.begin(), ids.end() - 1});
-
-                llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1);
-
-                for (size_t i = 0; i < ids.size(); ++i) {
-                    completion_token_output result;
-
-                    result.tok = ids[i];
-                    result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
-                    result.prob = 1.0f; // set later
-
-                    // TODO: set result.probs
-
-                    if (!process_token(result, slot)) {
-                        // release slot because of stop condition
-                        slot.release();
-                        slot.print_timings();
-                        send_final_response(slot);
-                        metrics.on_prediction(slot);
-                        break;
-                    }
-                }
-
-                SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
-            }
         }
 
         SRV_DBG("%s", "run slots completed\n");

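Note on the counters updated by the relocated speculative block: common_sampler_sample_and_accept_n returns the prefix of the draft that the target model agrees with plus one freshly sampled token, which is why n_past advances by ids.size() while only ids.size() - 1 of the tested draft tokens count as accepted. The following is a minimal standalone sketch of that accounting only; it is not llama.cpp code, the token values are made up, and the acceptance rule is simplified for illustration.

// standalone sketch (not llama.cpp code): how accepted draft tokens move the counters
#include <cstdio>
#include <vector>

int main() {
    // hypothetical token ids, chosen only for illustration
    std::vector<int> draft  = { 11, 22, 33, 44 };  // proposed by the draft model
    std::vector<int> target = { 11, 22, 99, 44 };  // what the target model samples at each position

    // accept the agreeing prefix, then stop at the first disagreement,
    // keeping the target's token there (simplified acceptance rule)
    std::vector<int> ids;
    for (size_t i = 0; i < draft.size(); ++i) {
        ids.push_back(target[i]);
        if (target[i] != draft[i]) {
            break;
        }
    }

    int n_past           = 100;                    // position before speculation
    int n_draft_total    = (int) draft.size();     // drafted tokens tested
    int n_draft_accepted = (int) ids.size() - 1;   // last emitted token is sampled, not drafted

    n_past += (int) ids.size();

    // prints: accepted 2/4 draft tokens, new n_past = 103
    printf("accepted %d/%d draft tokens, new n_past = %d\n", n_draft_accepted, n_draft_total, n_past);
    return 0;
}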