Skip to content

Commit 2dc2918

Browse files
committed
Include speculative decoding stats when timings_per_token is true
New fields added to the `timings` object: - draft_n : number of draft tokens generated - draft_accepted_n : number of draft tokens accepted - draft_accept_ratio: ratio of accepted/generated
1 parent f17a3bb commit 2dc2918

File tree

1 file changed

+39
-1
lines changed

1 file changed

+39
-1
lines changed

examples/server/server.cpp

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,8 +489,13 @@ struct result_timings {
489489
double predicted_per_token_ms;
490490
double predicted_per_second;
491491

492+
// Optional speculative metrics - only included when > 0
493+
int32_t draft_n = 0;
494+
int32_t draft_accepted_n = 0;
495+
double draft_accept_ratio = 0;
496+
492497
json to_json() const {
493-
return {
498+
json base = {
494499
{"prompt_n", prompt_n},
495500
{"prompt_ms", prompt_ms},
496501
{"prompt_per_token_ms", prompt_per_token_ms},
@@ -501,6 +506,14 @@ struct result_timings {
501506
{"predicted_per_token_ms", predicted_per_token_ms},
502507
{"predicted_per_second", predicted_per_second},
503508
};
509+
510+
if (draft_n > 0) {
511+
base["draft_n"] = draft_n;
512+
base["draft_accepted_n"] = draft_accepted_n;
513+
base["draft_accept_ratio"] = draft_accept_ratio;
514+
}
515+
516+
return base;
504517
}
505518
};
506519

@@ -1299,6 +1312,11 @@ struct server_slot {
12991312

13001313
std::function<void(int)> callback_on_release;
13011314

1315+
// Speculative decoding stats
1316+
int32_t n_draft_total = 0; // Total draft tokens generated
1317+
int32_t n_draft_accepted = 0; // Draft tokens actually accepted
1318+
float draft_accept_ratio = 0.f; // n_draft_accepted/n_draft_total
1319+
13021320
void reset() {
13031321
SLT_DBG(*this, "%s", "\n");
13041322

@@ -1315,6 +1333,11 @@ struct server_slot {
13151333

13161334
generated_tokens.clear();
13171335
generated_token_probs.clear();
1336+
1337+
// clear speculative decoding stats
1338+
n_draft_total = 0;
1339+
n_draft_accepted = 0;
1340+
draft_accept_ratio = 0.f;
13181341
}
13191342

13201343
bool is_non_causal() const {
@@ -1381,6 +1404,13 @@ struct server_slot {
13811404
timings.predicted_per_token_ms = t_token_generation / n_decoded;
13821405
timings.predicted_per_second = 1e3 / t_token_generation * n_decoded;
13831406

1407+
// Add speculative metrics
1408+
if (n_draft_total > 0) {
1409+
timings.draft_n = n_draft_total;
1410+
timings.draft_accepted_n = n_draft_accepted;
1411+
timings.draft_accept_ratio = draft_accept_ratio;
1412+
}
1413+
13841414
return timings;
13851415
}
13861416

@@ -3290,6 +3320,8 @@ struct server_context {
32903320

32913321
llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
32923322

3323+
slot.n_draft_total += draft.size();
3324+
32933325
// ignore small drafts
32943326
if (slot.params.speculative.n_min > (int) draft.size()) {
32953327
SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
@@ -3339,6 +3371,12 @@ struct server_context {
33393371
}
33403372
}
33413373

3374+
// Update speculative metrics
3375+
slot.n_draft_accepted += ids.size() - 1; // exclude last sampled token
3376+
if (slot.n_draft_total > 0) {
3377+
slot.draft_accept_ratio = (float)slot.n_draft_accepted / slot.n_draft_total;
3378+
}
3379+
33423380
SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
33433381
}
33443382
}

0 commit comments

Comments
 (0)