Commit c4c10bf

ngxson and ggerganov authored
server: move msg diffs tracking to HTTP thread (#17740)

* server: move msg diffs tracking to HTTP thread
* wip
* tool call tests ok
* minor : style
* cont : fix
* move states to server_response_reader
* add safe-guard
* fix
* fix 2

---------

Co-authored-by: Georgi Gerganov <[email protected]>

1 parent 817d743 commit c4c10bf
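
In short: chat-message diff tracking used to live on server_slot and run on the inference thread (see the removed update_chat_msg in server-context.cpp below); it now lives in per-task task_result_state objects owned by server_response_reader and is applied on the HTTP thread when results are read back. A minimal sketch of the new ownership model, using simplified stand-in types rather than the real server definitions:

// Sketch only: simplified stand-ins, not the actual llama.cpp types.
#include <cstddef>
#include <vector>

struct task_result_state { /* chat syntax + accumulated parse state */ };

struct task_result {
    size_t index = 0;                       // which task this result belongs to
    void update(task_result_state & state); // parse text, emit msg diffs
};

struct response_reader {
    std::vector<task_result_state> states;  // one entry per posted task

    // runs on the HTTP thread; the inference thread no longer parses chat msgs
    void apply_state(task_result & res) {
        if (!states.empty()) {
            res.update(states[res.index]);
        }
    }
};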

File tree

5 files changed (+165, -92 lines)


tools/server/server-context.cpp

Lines changed: 85 additions & 86 deletions
@@ -101,8 +101,6 @@ struct server_slot {
     std::string generated_text;
     llama_tokens generated_tokens;
 
-    common_chat_msg chat_msg;
-
     std::vector<completion_token_output> generated_token_probs;
 
     bool has_next_token = true;
@@ -153,9 +151,6 @@ struct server_slot {
 
     llama_token sampled;
 
-    common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
-    std::vector<std::string> generated_tool_call_ids;
-
     // stats
     size_t n_sent_text = 0; // number of sent text character
 
@@ -183,13 +178,10 @@ struct server_slot {
         stop = STOP_TYPE_NONE;
         stopping_word = "";
         n_sent_text = 0;
-        chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
 
         generated_tokens.clear();
         generated_token_probs.clear();
-        chat_msg = {};
         json_schema = json();
-        generated_tool_call_ids.clear();
 
         // clear speculative decoding stats
         n_draft_total = 0;
@@ -302,23 +294,6 @@ struct server_slot {
         return timings;
     }
 
-    const common_chat_msg & update_chat_msg(std::vector<common_chat_msg_diff> & diffs) {
-        GGML_ASSERT(task);
-
-        auto previous_msg = chat_msg;
-        SRV_DBG("Parsing chat message: %s\n", generated_text.c_str());
-        auto new_msg = common_chat_parse(
-            generated_text,
-            /* is_partial= */ stop != STOP_TYPE_EOS,
-            task->params.oaicompat_chat_syntax);
-        if (!new_msg.empty()) {
-            new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id);
-            chat_msg = new_msg;
-            diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg);
-        }
-        return chat_msg;
-    }
-
     size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
         GGML_ASSERT(task);
 
@@ -1284,8 +1259,6 @@ struct server_context_impl {
         } else {
             res->content = tkn.text_to_send;
             res->tokens = { tkn.tok };
-
-            slot.update_chat_msg(res->oaicompat_msg_diffs);
         }
 
         res->n_decoded = slot.n_decoded;
@@ -1317,8 +1290,14 @@ struct server_context_impl {
         res->id_slot = slot.id;
 
         res->index = slot.task->index;
-        res->content = slot.generated_text;
-        res->tokens = std::move(slot.generated_tokens);
+        // in stream mode, content and tokens are already in last partial chunk
+        if (slot.task->params.stream) {
+            res->content = "";
+            res->tokens = llama_tokens{};
+        } else {
+            res->content = std::move(slot.generated_text);
+            res->tokens = std::move(slot.generated_tokens);
+        }
         res->timings = slot.get_timings();
         res->prompt = slot.task->tokens.detokenize(ctx, true);
         res->response_fields = std::move(slot.task->params.response_fields);
@@ -1338,7 +1317,6 @@ struct server_context_impl {
         res->res_type = slot.task->params.res_type;
         res->oaicompat_model = slot.task->params.oaicompat_model;
         res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id;
-        res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs);
 
         // populate res.probs_output
         if (slot.task->params.sampling.n_probs > 0) {
@@ -2596,6 +2574,9 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
     try {
         std::vector<server_task> tasks;
 
+        // tracking generation state and partial tool calls
+        std::vector<task_result_state> states;
+
         const auto & prompt = data.at("prompt");
         // TODO: this log can become very long, put it behind a flag or think about a more compact format
         //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
@@ -2611,6 +2592,7 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
             inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
         }
         tasks.reserve(inputs.size());
+        states.reserve(inputs.size());
        for (size_t i = 0; i < inputs.size(); i++) {
             server_task task = server_task(type);
 
@@ -2628,10 +2610,12 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
             task.params.res_type = res_type;
             task.params.oaicompat_cmpl_id = completion_id;
             task.params.oaicompat_model = ctx_server.model_name;
+            states.push_back(task.params.oaicompat_chat_syntax);
 
             tasks.push_back(std::move(task));
         }
 
+        rd.set_states(std::move(states));
         rd.post_tasks(std::move(tasks));
     } catch (const std::exception & e) {
         res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
@@ -2657,7 +2641,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
             // if single request, return single object instead of array
             res->ok(arr.size() == 1 ? arr[0] : arr);
         }
-
     } else {
         // in streaming mode, the first error must be treated as non-stream response
         // this is to match the OAI API behavior
@@ -2676,76 +2659,92 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
         }
 
         // next responses are streamed
+        // to be sent immediately
+        json first_result_json = first_result->to_json();
         if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
-            res->data = format_anthropic_sse(first_result->to_json());
+            res->data = format_anthropic_sse(first_result_json);
         } else {
-            res->data = format_oai_sse(first_result->to_json()); // to be sent immediately
+            res->data = format_oai_sse(first_result_json);
         }
         res->status = 200;
         res->content_type = "text/event-stream";
         res->next = [res_this = res.get(), res_type, &should_stop](std::string & output) -> bool {
-            if (should_stop()) {
-                SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
-                return false; // should_stop condition met
-            }
-
-            if (!res_this->data.empty()) {
-                // flush the first chunk
-                output = std::move(res_this->data);
-                res_this->data.clear();
-                return true;
-            }
-
-            server_response_reader & rd = res_this->rd;
-
-            // check if there is more data
-            if (!rd.has_next()) {
-                if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
-                    // Anthropic doesn't send [DONE], message_stop was already sent
-                    output = "";
-                } else if (res_type != TASK_RESPONSE_TYPE_NONE) {
-                    output = "data: [DONE]\n\n";
-                } else {
-                    output = "";
-                }
-                SRV_DBG("%s", "all results received, terminating stream\n");
-                return false; // no more data, terminate
-            }
-
-            // receive subsequent results
-            auto result = rd.next(should_stop);
-            if (result == nullptr) {
-                SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
-                return false; // should_stop condition met
-            }
-
-            // send the results
-            json res_json = result->to_json();
-            if (result->is_error()) {
+            static auto format_error = [](task_response_type res_type, const json & res_json) {
                 if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
-                    output = format_anthropic_sse({
+                    return format_anthropic_sse({
                         {"event", "error"},
                         {"data", res_json},
                     });
                 } else {
-                    output = format_oai_sse(json {{ "error", res_json }});
+                    return format_oai_sse(json {{ "error", res_json }});
                 }
-                SRV_DBG("%s", "error received during streaming, terminating stream\n");
-                return false; // terminate on error
-            } else {
-                GGML_ASSERT(
-                    dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
-                    || dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
-                );
-                if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
-                    output = format_anthropic_sse(res_json);
+            };
+
+            try {
+                if (should_stop()) {
+                    SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
+                    return false; // should_stop condition met
+                }
+
+                if (!res_this->data.empty()) {
+                    // flush the first chunk
+                    output = std::move(res_this->data);
+                    res_this->data.clear();
+                    return true;
+                }
+
+                server_response_reader & rd = res_this->rd;
+
+                // check if there is more data
+                if (!rd.has_next()) {
+                    if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
+                        // Anthropic doesn't send [DONE], message_stop was already sent
+                        output = "";
+                    } else if (res_type != TASK_RESPONSE_TYPE_NONE) {
+                        output = "data: [DONE]\n\n";
+                    } else {
+                        output = "";
+                    }
+                    SRV_DBG("%s", "all results received, terminating stream\n");
+                    return false; // no more data, terminate
+                }
+
+                // receive subsequent results
+                auto result = rd.next(should_stop);
+                if (result == nullptr) {
+                    SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
+                    return false; // should_stop condition met
+                }
+
+                // send the results
+                if (result->is_error()) {
+                    json res_json = result->to_json();
+                    output = format_error(res_type, res_json);
+                    SRV_DBG("%s", "error received during streaming, terminating stream\n");
+                    return false; // terminate on error
                 } else {
-                    output = format_oai_sse(res_json);
+                    GGML_ASSERT(
+                        dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
+                        || dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
+                    );
+                    json res_json = result->to_json();
+                    if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
+                        output = format_anthropic_sse(res_json);
+                    } else {
+                        output = format_oai_sse(res_json);
+                    }
                 }
-            }
 
-            // has next data, continue
-            return true;
+                // has next data, continue
+                return true;
+
+            } catch (const std::exception & e) {
+                json error_json = format_error_response(e.what(), ERROR_TYPE_SERVER);
+                output = format_error(res_type, error_json);
+
+                // terminate on exception
+                return false;
+            }
        };
    }
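
The streaming branch above frames every chunk as a server-sent event and closes OAI-style streams with a literal data: [DONE] sentinel. For illustration, a sketch of the framing the handler assumes; the real format_oai_sse lives elsewhere in the server sources, so treat this as an assumption rather than the actual helper:

// Illustrative sketch of the SSE framing, not the real format_oai_sse.
#include <string>
#include <nlohmann/json.hpp>

static std::string format_oai_sse_sketch(const nlohmann::json & chunk) {
    // one "data:" line per event, terminated by a blank line;
    // the stream itself ends with the "data: [DONE]\n\n" sentinel
    return "data: " + chunk.dump() + "\n\n";
}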

tools/server/server-queue.cpp

Lines changed: 10 additions & 0 deletions
@@ -271,6 +271,10 @@ void server_response::terminate() {
 // server_response_reader
 //
 
+void server_response_reader::set_states(std::vector<task_result_state> && states) {
+    this->states = std::move(states);
+}
+
 void server_response_reader::post_tasks(std::vector<server_task> && tasks) {
     id_tasks = server_task::get_list_id(tasks);
     queue_results.add_waiting_tasks(tasks);
@@ -298,6 +302,12 @@ server_task_result_ptr server_response_reader::next(const std::function<bool()>
         SRV_DBG("%s", "received error result, stopping further processing\n");
         return result;
     }
+    if (!states.empty()) {
+        // update the generation state if needed
+        size_t idx = result->get_index();
+        GGML_ASSERT(idx < states.size());
+        result->update(states[idx]);
+    }
     if (result->is_stop()) {
         received_count++;
     }
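
The result->update(states[idx]) call above is where diff tracking now runs on the HTTP thread. Its definition lives in the server task code, which this commit view does not show; what follows is a hypothetical sketch inferred from the removed server_slot::update_chat_msg, where the state fields (generated_text, chat_msg, generated_tool_call_ids, oaicompat_chat_syntax) are assumptions about task_result_state, not confirmed names:

// Hypothetical sketch, not the actual implementation: accumulate this chunk's
// text into the per-task state, re-parse, and emit incremental msg diffs.
void update_sketch(task_result_state & state,
                   const std::string & chunk_text,
                   bool is_partial,
                   std::vector<common_chat_msg_diff> & diffs_out) {
    state.generated_text += chunk_text;          // now runs on the HTTP thread
    auto previous_msg = state.chat_msg;
    auto new_msg = common_chat_parse(
        state.generated_text,
        is_partial,
        state.oaicompat_chat_syntax);
    if (!new_msg.empty()) {
        new_msg.set_tool_call_ids(state.generated_tool_call_ids, gen_tool_call_id);
        state.chat_msg = new_msg;
        diffs_out = common_chat_msg_diff::compute_diffs(previous_msg, new_msg);
    }
}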

tools/server/server-queue.h

Lines changed: 9 additions & 0 deletions
@@ -7,6 +7,8 @@
 #include <mutex>
 #include <unordered_set>
 
+// struct for managing server tasks
+// in most cases, use server_response_reader to post new tasks and retrieve results
 struct server_queue {
 private:
     int id = 0;
@@ -67,6 +69,8 @@ struct server_queue {
     void cleanup_pending_task(int id_target);
 };
 
+// struct for managing server responses
+// in most cases, use server_response_reader to retrieve results
 struct server_response {
 private:
     bool running = true;
@@ -120,13 +124,18 @@ struct server_response_reader {
     bool cancelled = false;
     int polling_interval_seconds;
 
+    // tracking generation state and partial tool calls
+    // only used by streaming completions
+    std::vector<task_result_state> states;
+
     // should_stop function will be called each polling_interval_seconds
     server_response_reader(std::pair<server_queue &, server_response &> server_queues, int polling_interval_seconds)
         : queue_tasks(server_queues.first), queue_results(server_queues.second), polling_interval_seconds(polling_interval_seconds) {}
     ~server_response_reader() {
         stop();
     }
 
+    void set_states(std::vector<task_result_state> && states);
     void post_tasks(std::vector<server_task> && tasks);
     bool has_next() const;