examples/server/server.cpp (42 changes: 41 additions & 1 deletion)
@@ -489,8 +489,12 @@ struct result_timings {
     double predicted_per_token_ms;
     double predicted_per_second;
 
+    // Optional speculative metrics - only included when draft_n > 0
+    int32_t draft_n = 0;
+    int32_t draft_n_accepted = 0;
+
     json to_json() const {
-        return {
+        json base = {
             {"prompt_n", prompt_n},
             {"prompt_ms", prompt_ms},
             {"prompt_per_token_ms", prompt_per_token_ms},
@@ -501,6 +505,13 @@ struct result_timings {
             {"predicted_per_token_ms", predicted_per_token_ms},
             {"predicted_per_second", predicted_per_second},
         };
+
+        if (draft_n > 0) {
+            base["draft_n"] = draft_n;
+            base["draft_n_accepted"] = draft_n_accepted;
+        }
+
+        return base;
     }
 };

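As a quick illustration of the conditional-field pattern above, here is a minimal, self-contained sketch, assuming nlohmann/json (the library behind the server's `json` type); the surrounding timing fields and all metric values are hypothetical placeholders, not taken from the PR:

#include <cstdint>
#include <cstdio>
#include <nlohmann/json.hpp>

using json = nlohmann::ordered_json;

int main() {
    // hypothetical run: 8 of 10 drafted tokens were accepted
    int32_t draft_n          = 10;
    int32_t draft_n_accepted = 8;

    json base = {
        {"predicted_n", 32},     // placeholder timing fields
        {"predicted_ms", 512.0},
    };

    // speculative metrics are attached only when a draft model actually ran
    if (draft_n > 0) {
        base["draft_n"]          = draft_n;
        base["draft_n_accepted"] = draft_n_accepted;
    }

    printf("%s\n", base.dump(2).c_str());
    return 0;
}

A consequence of this design is that clients can detect whether speculative decoding was active for a request simply by checking for the presence of the "draft_n" key in the timings object.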
@@ -1299,6 +1310,10 @@ struct server_slot {

     std::function<void(int)> callback_on_release;
 
+    // Speculative decoding stats
+    int32_t n_draft_total = 0;    // Total draft tokens generated
+    int32_t n_draft_accepted = 0; // Draft tokens actually accepted
+
     void reset() {
         SLT_DBG(*this, "%s", "\n");

@@ -1315,6 +1330,10 @@

         generated_tokens.clear();
         generated_token_probs.clear();
+
+        // clear speculative decoding stats
+        n_draft_total = 0;
+        n_draft_accepted = 0;
     }
 
     bool is_non_causal() const {
@@ -1381,6 +1400,12 @@ struct server_slot {
         timings.predicted_per_token_ms = t_token_generation / n_decoded;
         timings.predicted_per_second = 1e3 / t_token_generation * n_decoded;
 
+        // Add speculative metrics
+        if (n_draft_total > 0) {
+            timings.draft_n = n_draft_total;
+            timings.draft_n_accepted = n_draft_accepted;
+        }
+
         return timings;
     }

@@ -1428,6 +1453,15 @@ struct server_slot {
                 t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second,
                 t_token_generation, n_decoded, t_gen, n_gen_second,
                 t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded);
+
+        if (n_draft_total > 0) {
+            const float draft_ratio = (float) n_draft_accepted / n_draft_total;
+            SLT_INF(*this,
+                    "\n"
+                    "draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n",
+                    draft_ratio, n_draft_accepted, n_draft_total
+            );
+        }
     }
 
     json to_json() const {
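For a sense of the output (the counts here are hypothetical, not from the PR): with 8 accepted out of 10 generated draft tokens, the format string above would render as

draft acceptance rate = 0.80000 (    8 accepted /    10 generated)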
@@ -3290,6 +3324,9 @@ struct server_context {

             llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
 
+            // keep track of total number of tokens generated in the draft
+            slot.n_draft_total += draft.size();
+
             // ignore small drafts
             if (slot.params.speculative.n_min > (int) draft.size()) {
                 SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
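Note that `n_draft_total` is incremented before the `speculative.n_min` size filter, so drafts that end up ignored for being too small still count as generated-but-never-accepted and pull the reported acceptance rate down.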
@@ -3315,6 +3352,9 @@ struct server_context {
                 slot.n_past += ids.size();
                 slot.n_decoded += ids.size();
 
+                // update how many tokens out of the draft were accepted
+                slot.n_draft_accepted += ids.size() - 1;
+
                 slot.cache_tokens.push_back(id);
                 slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
 
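The `- 1` reflects that the last entry of `ids` is a token the target model sampled itself (the first token that diverged from the draft, or a bonus token when the entire draft matched) rather than an accepted draft token; the cache update above, which excludes `ids.end() - 1`, points at the same reading. A worked example of the bookkeeping, with hypothetical tokens:

// draft from the small model:  [t1, t2, t3, t4, t5]  -> n_draft_total    += 5
// ids after verification:      [t1, t2, t3, x]       -> n_draft_accepted += 3  (ids.size() - 1)
// acceptance rate so far: 3 / 5 = 0.60000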