
Commit c42712b

server: support multiple generations from one prompt (OAI "n" option) (ggml-org#17775)
* backend support
* server: support multiple generations from one prompt (OAI "n" option)
* fix invalid batch
* format oai
* clean up
* disable ctx shift
* add test
* update comments
* fix style
* add n_cmpl to docs [no ci]
* allow using both n_cmpl and n
1 parent 09c7c50 commit c42712b

File tree

7 files changed: +146 −19 lines changed

tools/server/README.md

Lines changed: 2 additions & 0 deletions
@@ -493,6 +493,8 @@ Note for `multimodal_data` in JSON object prompts. This should be an array of st
 `n_keep`: Specify the number of tokens from the prompt to retain when the context size is exceeded and tokens need to be discarded. The number excludes the BOS token.
 By default, this value is set to `0`, meaning no tokens are kept. Use `-1` to retain all tokens from the prompt.

+`n_cmpl`: Number of completions to generate from the current prompt. If the input has multiple prompts, the output will have N prompts times `n_cmpl` entries.
+
 `stream`: Allows receiving each predicted token in real-time instead of waiting for the completion to finish (uses a different response format). To enable this, set to `true`.

 `stop`: Specify a JSON array of stopping strings.
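
As a usage illustration (not part of the diff), a minimal sketch of calling the new option against a locally running llama-server. The host/port and prompt are assumptions; per the README text above, requesting more than one completion from the non-OAI `/completion` endpoint yields multiple result entries.

```python
# Hedged sketch: assumes llama-server is listening on localhost:8080 and that,
# with n_cmpl > 1, /completion returns a JSON array of result objects.
import requests

resp = requests.post("http://localhost:8080/completion", json={
    "prompt": "Once upon a time",
    "n_predict": 8,   # tokens to generate per completion
    "n_cmpl": 2,      # number of completions to generate from this one prompt
})
resp.raise_for_status()

results = resp.json()   # expected: an array with 2 entries for the single prompt
for r in results:
    print(r.get("index"), r.get("content"))
```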

tools/server/server-common.cpp

Lines changed: 12 additions & 12 deletions
@@ -494,6 +494,18 @@ int32_t server_tokens::process_chunk(
     return 0;
 }

+server_tokens server_tokens::clone() const {
+    server_tokens res;
+    res.has_mtmd = has_mtmd;
+    res.tokens = tokens;
+    for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) {
+        size_t idx = it->first;
+        const mtmd::input_chunk_ptr & chunk = it->second;
+        res.map_idx_to_media[idx] = mtmd::input_chunk_ptr(mtmd_input_chunk_copy(chunk.get()));
+    }
+    return res;
+}
+
 //
 // tokenizer and input processing utils
 //
@@ -745,12 +757,6 @@ json oaicompat_completion_params_parse(const json & body) {
         llama_params["stop"] = json_value(body, "stop", json::array());
     }

-    // Handle "n" field
-    int n_choices = json_value(body, "n", 1);
-    if (n_choices != 1) {
-        throw std::runtime_error("Only one completion choice is allowed");
-    }
-
     // Handle "echo" field
     if (json_value(body, "echo", false)) {
         throw std::runtime_error("Only no echo is supported");
@@ -1049,12 +1055,6 @@ json oaicompat_chat_params_parse(
         llama_params["chat_parser"] = chat_params.parser;
     }

-    // Handle "n" field
-    int n_choices = json_value(body, "n", 1);
-    if (n_choices != 1) {
-        throw std::invalid_argument("Only one completion choice is allowed");
-    }
-
     // Handle "logprobs" field
     // TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future
     if (json_value(body, "logprobs", false)) {

tools/server/server-common.h

Lines changed: 2 additions & 0 deletions
@@ -215,6 +215,8 @@ struct server_tokens {
         llama_pos pos,
         int32_t seq_id,
         size_t & n_tokens_out) const;
+
+    server_tokens clone() const;
 };

tools/server/server-context.cpp

Lines changed: 80 additions & 5 deletions
@@ -35,7 +35,8 @@ constexpr int HTTP_POLLING_SECONDS = 1;
 // state diagram: https://github.com/ggml-org/llama.cpp/pull/9283
 enum slot_state {
     SLOT_STATE_IDLE,
-    SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future
+    SLOT_STATE_WAIT_OTHER, // after assigning a task, but waiting for parent slot to process prompt
+    SLOT_STATE_STARTED,    // after assigning a task and about to process prompt
     SLOT_STATE_PROCESSING_PROMPT,
     SLOT_STATE_DONE_PROMPT,
     SLOT_STATE_GENERATING,
@@ -254,6 +255,15 @@ struct server_slot {
         generated_token_probs.push_back(token);
     }

+    // note: a slot can also be either a parent or a child
+    bool is_parent() const {
+        return is_processing() && task->n_children > 0;
+    }
+
+    bool is_child() const {
+        return is_processing() && task->id_parent >= 0;
+    }
+
     void release() {
         if (is_processing()) {
             GGML_ASSERT(task);
@@ -383,6 +393,17 @@ struct server_slot {

         return res;
     }
+
+    void copy_state_to(server_slot & other) const {
+        llama_memory_seq_rm(llama_get_memory(ctx), other.id, 0, -1);
+        llama_memory_seq_cp(llama_get_memory(ctx), id, other.id, 0, -1);
+        other.n_decoded = n_decoded;
+        other.n_remaining = n_remaining;
+        other.i_batch = i_batch;
+        other.n_prompt_tokens_cache = n_prompt_tokens_cache;
+        other.n_prompt_tokens_processed = n_prompt_tokens_processed;
+        other.prompt = prompt.clone();
+    }
 };

@@ -1022,7 +1043,9 @@ struct server_context_impl {

         slot.task = std::make_unique<const server_task>(std::move(task));

-        slot.state = SLOT_STATE_STARTED;
+        slot.state = slot.is_child()
+            ? SLOT_STATE_WAIT_OTHER // wait for the parent to process prompt
+            : SLOT_STATE_STARTED;

         SLT_INF(slot, "%s", "processing task\n");

@@ -1684,6 +1707,12 @@ struct server_context_impl {
                 GGML_ABORT("not supported by multimodal");
             }

+            if (slot.is_parent() || slot.is_child()) {
+                send_error(slot, "context shift cannot be used for shared prompt", ERROR_TYPE_SERVER);
+                slot.release();
+                continue;
+            }
+
             // Shift context
             int n_keep = slot.task->params.n_keep < 0 ? slot.task->n_tokens() : slot.task->params.n_keep;

@@ -2308,6 +2337,26 @@ struct server_context_impl {
         n_batch = llama_n_batch(ctx);

         for (auto & slot : slots) {
+            // may need to copy state to other slots
+            if (slot.state == SLOT_STATE_DONE_PROMPT && slot.is_parent()) {
+                std::vector<server_slot *> child_slots;
+                for (auto & other : slots) {
+                    if (other.state == SLOT_STATE_WAIT_OTHER && slot.task->id == other.task->id_parent) {
+                        child_slots.push_back(&other);
+                    }
+                }
+
+                // we can only proceed if all child slots have the correct tasks
+                if (child_slots.size() == slot.task->n_children) {
+                    // copy state to the child slots
+                    for (auto & child : child_slots) {
+                        SLT_INF(slot, "copying state to child %d\n", child->id);
+                        slot.copy_state_to(*child);
+                        child->state = SLOT_STATE_DONE_PROMPT;
+                    }
+                }
+            }
+
             // optionally send prompt processing progress
             if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) {
                 if (slot.task->params.stream && slot.task->params.return_progress) {
@@ -2593,11 +2642,12 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
     }
     tasks.reserve(inputs.size());
     states.reserve(inputs.size());
+    int idx = 0;
     for (size_t i = 0; i < inputs.size(); i++) {
         server_task task = server_task(type);

         task.id = ctx_server.queue_tasks.get_new_id();
-        task.index = i;
+        task.index = idx++;

         task.tokens = std::move(inputs[i]);
         task.params = server_task::params_from_json_cmpl(
@@ -2612,6 +2662,18 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
         task.params.oaicompat_model = ctx_server.model_name;
         states.push_back(task.params.oaicompat_chat_syntax);

+        if (task.params.n_cmpl > 1) {
+            task.n_children = task.params.n_cmpl - 1;
+            for (size_t j = 0; j < task.n_children; j++) {
+                server_task child = task.create_child(
+                    task.id,
+                    ctx_server.queue_tasks.get_new_id(),
+                    idx++);
+                states.push_back(child.params.oaicompat_chat_syntax);
+                tasks.push_back(std::move(child));
+            }
+        }
+
         tasks.push_back(std::move(task));
     }

@@ -2638,8 +2700,21 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
             GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(res.get()) != nullptr);
             arr.push_back(res->to_json());
         }
-        // if single request, return single object instead of array
-        res->ok(arr.size() == 1 ? arr[0] : arr);
+        GGML_ASSERT(!arr.empty() && "empty results");
+        if (arr.size() == 1) {
+            // if single request, return single object instead of array
+            res->ok(arr[0]);
+        } else if (res_type == TASK_RESPONSE_TYPE_OAI_CHAT || res_type == TASK_RESPONSE_TYPE_OAI_CMPL) {
+            // if multiple results in OAI format, we need to re-format them
+            json & choices = arr[0]["choices"];
+            for (size_t i = 1; i < arr.size(); i++) {
+                choices.push_back(std::move(arr[i]["choices"][0]));
+            }
+            res->ok(arr[0]);
+        } else {
+            // multi-results, non-OAI compat
+            res->ok(arr);
+        }
     }
 } else {
     // in streaming mode, the first error must be treated as non-stream response
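
To make the multi-result branch above easier to follow, here is a standalone restatement of the same merge in Python (illustrative only; the dictionaries stand in for the per-task JSON results, each carrying one choice whose index is set in server-task.cpp):

```python
# Illustrative sketch of the OAI-format merge in handle_completions_impl:
# each finished task produced one response object with a single choice, and the
# choices of results 1..N-1 are appended onto the choices of result 0.
results = [
    {"object": "chat.completion", "choices": [{"index": 0, "message": {"content": "A"}}]},
    {"object": "chat.completion", "choices": [{"index": 1, "message": {"content": "B"}}]},
]

merged = results[0]
for extra in results[1:]:
    merged["choices"].append(extra["choices"][0])

# merged now holds a single response whose "choices" array contains every completion
assert [c["index"] for c in merged["choices"]] == [0, 1]
```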

tools/server/server-task.cpp

Lines changed: 7 additions & 2 deletions
@@ -175,6 +175,7 @@ task_params server_task::params_from_json_cmpl(
     params.n_indent = json_value(data, "n_indent", defaults.n_indent);
     params.n_keep = json_value(data, "n_keep", defaults.n_keep);
     params.n_discard = json_value(data, "n_discard", defaults.n_discard);
+    params.n_cmpl = json_value(data, "n_cmpl", json_value(data, "n", 1));
     //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
     params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
     params.response_fields = json_value(data, "response_fields", std::vector<std::string>());
@@ -453,6 +454,10 @@ task_params server_task::params_from_json_cmpl(
         }
     }

+    if (params.n_cmpl > params_base.n_parallel) {
+        throw std::runtime_error("n_cmpl cannot be greater than the number of slots, please increase -np");
+    }
+
     return params;
 }

@@ -664,7 +669,7 @@ json server_task_result_cmpl_final::to_json_oaicompat_chat() {

     json choice {
         {"finish_reason", finish_reason},
-        {"index", 0},
+        {"index", index},
         {"message", msg.to_json_oaicompat<json>()},
     };

@@ -1064,7 +1069,7 @@ json server_task_result_cmpl_partial::to_json_oaicompat_chat() {
         {"choices", json::array({
             json {
                 {"finish_reason", nullptr},
-                {"index", 0},
+                {"index", index},
                 {"delta", delta},
             },
         })},
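
The parameter handling added in this file amounts to: `n_cmpl` takes precedence over the OAI `n` field, defaults to 1, and must not exceed the number of parallel slots. A hedged Python restatement of that logic (the function name is invented for illustration):

```python
def resolve_n_cmpl(body: dict, n_parallel: int) -> int:
    """Mirror of json_value(data, "n_cmpl", json_value(data, "n", 1)) plus the slot check."""
    # "n_cmpl" wins if present, otherwise fall back to the OAI "n" field, otherwise 1
    n_cmpl = body.get("n_cmpl", body.get("n", 1))
    if n_cmpl > n_parallel:
        raise ValueError("n_cmpl cannot be greater than the number of slots, please increase -np")
    return n_cmpl

assert resolve_n_cmpl({"n": 3}, n_parallel=4) == 3
assert resolve_n_cmpl({"n": 2, "n_cmpl": 5}, n_parallel=8) == 5
```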

tools/server/server-task.h

Lines changed: 24 additions & 0 deletions
@@ -53,6 +53,7 @@ struct task_params {
     int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
     int32_t n_predict = -1; // new tokens to predict
     int32_t n_indent = 0; // minimum line indentation for the generated text in number of whitespace characters
+    int32_t n_cmpl = 1; // number of completions to generate from this prompt

     int64_t t_max_prompt_ms = -1; // TODO: implement
     int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
@@ -89,6 +90,10 @@ struct server_task {
     int id_target = -1;
     int id_slot = -1;

+    // used by parallel sampling (multiple completions from same prompt)
+    size_t n_children = 0; // number of tasks reusing this prompt
+    int id_parent = -1;
+
     // used by SERVER_TASK_TYPE_INFERENCE
     task_params params;
     server_tokens tokens;
@@ -130,6 +135,17 @@ struct server_task {
         }
         return ids;
     }
+
+    server_task create_child(int id_parent, int id_child, int idx) const {
+        server_task copy;
+        copy.id = id_child;
+        copy.index = idx;
+        copy.id_parent = id_parent;
+        copy.params = params;
+        copy.type = type;
+        copy.tokens = tokens.clone();
+        return copy;
+    }
 };

 struct result_timings {
@@ -466,6 +482,14 @@ struct server_prompt {
     int n_tokens() const {
         return tokens.size();
     }
+
+    server_prompt clone() const {
+        return server_prompt {
+            tokens.clone(),
+            data,
+            checkpoints
+        };
+    }
 };

 struct server_prompt_cache {
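
For orientation, the fan-out implied by `create_child` and the `clone` methods can be sketched with a simplified Python model (not the server's actual types; ids and the token list are stand-ins):

```python
from dataclasses import dataclass
from itertools import count

_next_id = count(100)  # stand-in for queue_tasks.get_new_id()

@dataclass
class Task:
    id: int
    index: int
    tokens: list            # stand-in for server_tokens; each child gets its own copy
    n_children: int = 0
    id_parent: int = -1

    def create_child(self, idx: int) -> "Task":
        # mirrors server_task::create_child: fresh id and index, a deep copy of the
        # prompt tokens, and a back-reference to the parent task
        return Task(id=next(_next_id), index=idx,
                    tokens=list(self.tokens), id_parent=self.id)

parent = Task(id=next(_next_id), index=0, tokens=[1, 2, 3])
n_cmpl = 3
parent.n_children = n_cmpl - 1
children = [parent.create_child(idx=i + 1) for i in range(parent.n_children)]

assert all(c.id_parent == parent.id for c in children)
assert all(c.tokens == parent.tokens and c.tokens is not parent.tokens for c in children)
```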

tools/server/tests/unit/test_chat_completion.py

Lines changed: 19 additions & 0 deletions
@@ -477,3 +477,22 @@ def make_cmpl_request():
     assert last_progress["total"] > 0
     assert last_progress["processed"] == last_progress["total"]
     assert total_batch_count == batch_count
+
+
+def test_chat_completions_multiple_choices():
+    global server
+    server.start()
+    res = server.make_request("POST", "/chat/completions", data={
+        "max_tokens": 8,
+        "n": 2,
+        "messages": [
+            {"role": "system", "content": "Book"},
+            {"role": "user", "content": "What is the best book"},
+        ],
+    })
+    assert res.status_code == 200
+    assert len(res.body["choices"]) == 2
+    for choice in res.body["choices"]:
+        assert "assistant" == choice["message"]["role"]
+        assert match_regex("Suddenly", choice["message"]["content"])
+        assert choice["finish_reason"] == "length"
