Commit dd3e54f

feat: openai-style lookup decoding for server

Eero Lihavainen authored and committed
1 parent 06c2b15 commit dd3e54f

File tree: 2 files changed (+228, -5 lines)
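
This commit teaches the llama.cpp example server to accept OpenAI-style predicted outputs: the client supplies a `prediction` object whose `content` is used as a lookup-decoding draft, and the response reports how many of those draft tokens were accepted under `usage.completion_tokens_details.accepted_prediction_tokens`. A minimal client sketch, assuming a server already listening on `http://localhost:8080` (host, port, and the example prompt are placeholders, not part of the commit):

```python
import requests  # third-party HTTP client, used only for this illustration

# Ask for a small edit to a known piece of text and pass the old text as the
# prediction; most of the output should match it and be accepted verbatim.
resp = requests.post(
    "http://localhost:8080/v1/chat/completions",  # assumed server address
    json={
        "messages": [{"role": "user", "content": "Fix the typo: Teh meaning of life is 42."}],
        "temperature": 0.0,
        "prediction": {"content": "The meaning of life is 42."},
    },
    timeout=60,
)
body = resp.json()

print(body["choices"][0]["message"]["content"])
# New in this commit: how many predicted tokens the server accepted.
print(body["usage"]["completion_tokens_details"]["accepted_prediction_tokens"])
```

Predictions pay off when the output is expected to repeat most of the supplied text (rewriting a file with a small change, reformatting, and so on); when the prediction does not match, generation simply falls back to normal decoding.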

examples/server/server.cpp (166 additions, 5 deletions)
@@ -197,6 +197,7 @@ struct server_task {
     // used by SERVER_TASK_TYPE_INFERENCE
     slot_params params;
     llama_tokens prompt_tokens;
+    llama_tokens prediction_tokens;
     int id_selected_slot = -1;
 
     // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE
@@ -604,6 +605,7 @@ struct server_task_result_cmpl_final : server_task_result {
     int32_t n_decoded;
     int32_t n_prompt_tokens;
     int32_t n_tokens_cached;
+    int32_t n_lookup_used;
     bool has_new_line;
     std::string stopping_word;
     stop_type stop = STOP_TYPE_NONE;
@@ -660,6 +662,7 @@ struct server_task_result_cmpl_final : server_task_result {
             {"stopping_word", stopping_word},
             {"tokens_cached", n_tokens_cached},
             {"timings", timings.to_json()},
+            {"prediction_tokens_accepted", n_lookup_used},
         };
         if (!stream && !probs_output.empty()) {
             res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
@@ -695,7 +698,10 @@ struct server_task_result_cmpl_final : server_task_result {
             {"usage", json {
                 {"completion_tokens", n_decoded},
                 {"prompt_tokens", n_prompt_tokens},
-                {"total_tokens", n_decoded + n_prompt_tokens}
+                {"total_tokens", n_decoded + n_prompt_tokens},
+                {"completion_tokens_details", json {
+                    {"accepted_prediction_tokens", n_lookup_used },
+                }}
             }},
             {"id", oaicompat_cmpl_id}
         };
@@ -771,11 +777,14 @@ struct server_task_result_cmpl_final : server_task_result {
             {"usage", json {
                 {"completion_tokens", n_decoded},
                 {"prompt_tokens", n_prompt_tokens},
-                {"total_tokens", n_decoded + n_prompt_tokens}
+                {"total_tokens", n_decoded + n_prompt_tokens},
+                {"completion_tokens_details", json {
+                    {"accepted_prediction_tokens", n_lookup_used },
+                }}
             }},
             {"id", oaicompat_cmpl_id}
         };
-
+
         // extra fields for debugging purposes
         if (verbose) {
             res["__verbose"] = to_json_non_oaicompat();
@@ -811,6 +820,9 @@ struct server_task_result_cmpl_final : server_task_result {
                 {"completion_tokens", n_decoded},
                 {"prompt_tokens", n_prompt_tokens},
                 {"total_tokens", n_decoded + n_prompt_tokens},
+                {"completion_tokens_details", json {
+                    {"accepted_prediction_tokens", n_lookup_used },
+                }}
             }},
         };
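
All three OpenAI-compatible response variants above gain the same nested field. For reference, a sketch of the resulting `usage` object (the numbers are made up; only the structure comes from this diff):

```python
# Illustrative shape of the "usage" block returned by the server after this change.
usage = {
    "completion_tokens": 48,
    "prompt_tokens": 12,
    "total_tokens": 60,
    "completion_tokens_details": {
        "accepted_prediction_tokens": 31,  # tokens taken directly from the prediction
    },
}
```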

@@ -1235,16 +1247,22 @@ struct server_slot {
     int32_t n_ctx = 0; // context size per slot
     int32_t n_past = 0;
     int32_t n_decoded = 0;
+    int32_t n_lookup_used = 0;
     int32_t n_remaining = -1;
     int32_t i_batch = -1;
     int32_t n_predict = -1; // TODO: disambiguate from params.n_predict
 
+    // for "predicted outputs"
+    int32_t lookup_n_adaptive = 1;
+    int32_t lookup_index = 0;
+
     // n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated
     int32_t n_prompt_tokens = 0;
     int32_t n_prompt_tokens_processed = 0;
 
     // input prompt tokens
     llama_tokens prompt_tokens;
+    llama_tokens prediction_tokens;
 
     size_t last_nl_pos = 0;
@@ -1912,9 +1930,8 @@ struct server_context {
         slot.n_ctx = n_ctx_slot;
         slot.n_predict = params_base.n_predict;
 
+        slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
         if (model_dft) {
-            slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
-
             slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
             if (slot.ctx_dft == nullptr) {
                 SRV_ERR("%s", "failed to create draft context\n");
@@ -2034,6 +2051,7 @@ struct server_context {
         slot.task_type = task.type;
         slot.params = std::move(task.params);
         slot.prompt_tokens = std::move(task.prompt_tokens);
+        slot.prediction_tokens = std::move(task.prediction_tokens);
 
         if (!are_lora_equal(task.params.lora, slot.lora)) {
             // if lora is changed, we cannot reuse cached tokens
@@ -2345,6 +2363,7 @@ struct server_context {
         res->n_decoded = slot.n_decoded;
         res->n_prompt_tokens = slot.n_prompt_tokens;
         res->n_tokens_cached = slot.n_past;
+        res->n_lookup_used = slot.n_lookup_used;
         res->has_new_line = slot.has_new_line;
         res->stopping_word = slot.stopping_word;
         res->stop = slot.stop;
@@ -3217,6 +3236,137 @@ struct server_context {
             }
         }
 
+        // apply "predicted outputs" i.e. user-specified speculation
+        // using a simple lookup decoding method
+        for (auto & slot : slots) {
+            // don't use lookup if we are also using a draft model
+            if (slot.can_speculate() || !slot.is_processing() || slot.prediction_tokens.size() < 2) {
+                continue;
+            }
+            if (slot.state != SLOT_STATE_GENERATING) {
+                continue;
+            }
+
+            // adaptive speculation window:
+            // increase window size every time all drafted tokens were accepted,
+            // otherwise reset to zero
+            auto draft_start_pos = 1;
+            bool found = false;
+            // first look for a match from the expected position
+            SLT_DBG(slot, "Looking up prediction tokens at index %d/%d\n", (int) slot.lookup_index, (int) slot.prediction_tokens.size());
+            if (slot.lookup_index > 0 &&
+                slot.lookup_index < static_cast<int32_t>(slot.prediction_tokens.size()) &&
+                slot.prediction_tokens[slot.lookup_index-1] == slot.sampled) {
+                found = true;
+                draft_start_pos = slot.lookup_index;
+                // TODO what is a good scaling law here?
+                // going for too large windows too fast will likely fail,
+                // but also too small windows in the beginning hurt perf
+                slot.lookup_n_adaptive = std::max(16, slot.lookup_n_adaptive*2);
+            } else {
+                // find first match in prediction_tokens
+                slot.lookup_n_adaptive = 1; // default
+                for (; draft_start_pos < static_cast<int32_t>(slot.prediction_tokens.size()); draft_start_pos++) {
+                    if (slot.prediction_tokens[draft_start_pos-1] == slot.sampled) {
+                        found = true;
+                        break;
+                    }
+                }
+            }
+            if (!found) continue;
+
+            // we erase the accepted tokens later, so we're looking for the same position next time
+            // increment by one because the next token will be generated
+            slot.lookup_index = draft_start_pos + 1;
+
+            llama_tokens draft = std::vector(
+                slot.prediction_tokens.begin() + draft_start_pos,
+                slot.prediction_tokens.end()
+            );
+
+            // determine the max draft that fits the current slot state
+            int n_draft_max = slot.lookup_n_adaptive;
+            n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
+
+            if (slot.n_remaining > 0) {
+                n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
+            }
+
+            n_draft_max = std::min(n_draft_max, static_cast<int>(draft.size()));
+            // NOTE: we use speculative.n_max here as the upper limit, but
+            // in general we want to allow large drafts, as opposed to when
+            // using a draft model. But this is linked to `slot.batch_spec`
+            // size also.
+            n_draft_max = std::min(n_draft_max, slot.params.speculative.n_max);
+
+            SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
+
+            draft.resize(n_draft_max);
+
+            llama_token id = slot.sampled;
+
+            // construct the speculation batch
+            common_batch_clear(slot.batch_spec);
+            common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true);
+
+            for (size_t i = 0; i < draft.size(); ++i) {
+                common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
+            }
+
+            llama_decode(ctx, slot.batch_spec);
+
+            // the accepted tokens from the speculation
+            // TODO can we stream these? Would be nice to reduce jankiness in UIs
+            const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
+
+            const auto n_accepted = ids.size() - 1;
+            slot.n_lookup_used += n_accepted;
+
+            if (n_accepted > 0) {
+                // remove the prediction tokens that were used + the next token
+                // (because it will be generated)
+                slot.prediction_tokens.erase(
+                    slot.prediction_tokens.begin() + draft_start_pos,
+                    std::min(
+                        slot.prediction_tokens.end(),
+                        slot.prediction_tokens.begin() + draft_start_pos + n_accepted + 1
+                    )
+                );
+                if (n_accepted < draft.size()) {
+                    // reset speculation as we didn't use the full draft
+                    slot.lookup_n_adaptive = 1;
+                }
+            }
+
+            for (size_t i = 0; i < ids.size(); ++i) {
+                // NOTE: we need to update these here to avoid stopping early
+                slot.n_past++;
+                slot.n_decoded++;
+                completion_token_output result;
+
+                result.tok = ids[i];
+                result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
+                result.prob = 1.0f; // set later
+
+                // TODO: set result.probs
+                if (!process_token(result, slot)) {
+                    // release slot because of stop condition
+                    slot.release();
+                    slot.print_timings();
+                    send_final_response(slot);
+                    metrics.on_prediction(slot);
+                    break;
+                }
+            }
+
+            slot.cache_tokens.push_back(id);
+            slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
+
+            llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
+
+            SLT_DBG(slot, "accepted %d/%d prediction tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
+        }
+
         // do speculative decoding
         for (auto & slot : slots) {
             if (!slot.is_processing() || !slot.can_speculate()) {
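
The block above is the core of the feature: after each normally sampled token, the server looks for that token inside the user-supplied prediction, decodes the tokens that follow it as a draft in a single batch, and lets the sampler accept the longest matching prefix. The draft window doubles while full drafts keep being accepted and drops back to a minimal size on a miss. A standalone sketch of just the matching and window logic, with hypothetical names (`propose_draft` and its simplified return shape are not part of the commit) and the batched verification step left out:

```python
from typing import List, Tuple

def propose_draft(prediction: List[int], sampled: int, lookup_index: int,
                  window: int, n_max: int) -> Tuple[List[int], int, int]:
    """Pick a draft from the prediction tokens, loosely mirroring the server loop above.

    Returns (draft, next_lookup_index, next_window); an empty draft means no match.
    """
    if 0 < lookup_index < len(prediction) and prediction[lookup_index - 1] == sampled:
        # cheap common case: we are still walking the prediction in order
        start = lookup_index
        window = max(16, window * 2)         # grow the window on continued success
    else:
        window = 1                           # miss: restart with a minimal window
        for start in range(1, len(prediction)):
            if prediction[start - 1] == sampled:
                break
        else:
            return [], lookup_index, window  # sampled token not found at all

    draft = prediction[start:start + min(window, n_max)]
    return draft, start + 1, window
```

In the server loop the accepted prefix is then erased from `slot.prediction_tokens`, so the next iteration keeps matching at the same `lookup_index`, and the KV cache is trimmed back to `n_past` to drop any rejected draft tokens.
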
@@ -3838,10 +3988,17 @@ int main(int argc, char ** argv) {
 
         try {
             const auto & prompt = data.at("prompt");
+            const auto & prediction_obj = json_value(data, "prediction", json());
+            const auto & prediction = json_value(prediction_obj, "content", std::string());
             // TODO: this log can become very long, put it behind a flag or think about a more compact format
             //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
 
             std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
+            std::vector<llama_tokens> tokenized_prediction;
+            if (!prediction.empty()) {
+                tokenized_prediction = tokenize_input_prompts(ctx_server.vocab, prediction, true, true);
+            }
+
             tasks.reserve(tokenized_prompts.size());
             for (size_t i = 0; i < tokenized_prompts.size(); i++) {
                 server_task task = server_task(type);
@@ -3850,6 +4007,10 @@ int main(int argc, char ** argv) {
                 task.index = i;
 
                 task.prompt_tokens = std::move(tokenized_prompts[i]);
+
+                if (!tokenized_prediction.empty()) {
+                    task.prediction_tokens = std::vector(tokenized_prediction[0].begin(), tokenized_prediction[0].end());
+                }
                 task.params = server_task::params_from_json_cmpl(
                     ctx_server.ctx,
                     ctx_server.params_base,
New file (62 additions, 0 deletions)
@@ -0,0 +1,62 @@
+import pytest
+from utils import *
+
+
+@pytest.fixture(scope="module", autouse=True)
+def create_server():
+    global server
+    server = ServerPreset.tinyllama2()
+    server.draft_max = 1024
+    server.debug = True
+
+
+def test_with_and_without_predicted_outputs():
+    global server
+    server.start()
+    res = server.make_request("POST", "/v1/chat/completions", data={
+        "messages": [{"role": "user", "content": "I believe the meaning of life is"}],
+        "temperature": 0.0,
+        "top_k": 1,
+    })
+    assert res.status_code == 200
+    assert res.body["usage"]["completion_tokens_details"]["accepted_prediction_tokens"] == 0
+    content_no_pred = res.body["choices"][0]["message"]["content"]
+    server.stop()
+
+    server.start()
+    res = server.make_request("POST", "/v1/chat/completions", data={
+        "messages": [{"role": "user", "content": "I believe the meaning of life is"}],
+        "temperature": 0.0,
+        "top_k": 1,
+        "prediction": {"content": '''"Here?" Annabyed.
+"Okay, Annabyes!" Annabyed.
+As Annagged, Annap came and said,'''}
+    })
+    assert res.status_code == 200
+    assert res.body["usage"]["completion_tokens_details"]["accepted_prediction_tokens"] == 54
+    content_pred = res.body["choices"][0]["message"]["content"]
+    server.stop()
+
+    assert content_no_pred == content_pred
+
+
+@pytest.mark.parametrize("n_slots,n_requests", [
+    (1, 2),
+    (2, 2),
+])
+def test_multi_requests_parallel(n_slots: int, n_requests: int):
+    global server
+    server.n_slots = n_slots
+    server.start()
+    # queue the requests and issue them concurrently instead of one by one
+    tasks = []
+    for _ in range(n_requests):
+        tasks.append((server.make_request, ("POST", "/v1/chat/completions", {
+            "messages": [{"role": "user", "content": "I believe the meaning of life is"}],
+            "temperature": 0.0,
+            "top_k": 1,
+            "prediction": {"content": " believe the meaning of life is"}
+        })))
+    results = parallel_function_calls(tasks)
+    for res in results:
+        assert res.status_code == 200
+        assert match_regex("(wise|kind|owl|answer)+", res.body["choices"][0]["message"]["content"])
