@@ -101,8 +101,6 @@ struct server_slot {
     std::string generated_text;
     llama_tokens generated_tokens;
 
-    common_chat_msg chat_msg;
-
     std::vector<completion_token_output> generated_token_probs;
 
     bool has_next_token = true;
@@ -153,9 +151,6 @@ struct server_slot {
 
     llama_token sampled;
 
-    common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
-    std::vector<std::string> generated_tool_call_ids;
-
     // stats
     size_t n_sent_text = 0; // number of sent text character
 
@@ -183,13 +178,10 @@ struct server_slot {
         stop = STOP_TYPE_NONE;
         stopping_word = "";
         n_sent_text = 0;
-        chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
 
         generated_tokens.clear();
         generated_token_probs.clear();
-        chat_msg = {};
         json_schema = json();
-        generated_tool_call_ids.clear();
 
         // clear speculative decoding stats
         n_draft_total = 0;
@@ -302,23 +294,6 @@ struct server_slot {
         return timings;
     }
 
-    const common_chat_msg & update_chat_msg(std::vector<common_chat_msg_diff> & diffs) {
-        GGML_ASSERT(task);
-
-        auto previous_msg = chat_msg;
-        SRV_DBG("Parsing chat message: %s\n", generated_text.c_str());
-        auto new_msg = common_chat_parse(
-            generated_text,
-            /* is_partial= */ stop != STOP_TYPE_EOS,
-            task->params.oaicompat_chat_syntax);
-        if (!new_msg.empty()) {
-            new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id);
-            chat_msg = new_msg;
-            diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg);
-        }
-        return chat_msg;
-    }
-
     size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
         GGML_ASSERT(task);
 
@@ -1284,8 +1259,6 @@ struct server_context_impl {
         } else {
             res->content = tkn.text_to_send;
             res->tokens = { tkn.tok };
-
-            slot.update_chat_msg(res->oaicompat_msg_diffs);
         }
 
         res->n_decoded = slot.n_decoded;
@@ -1317,8 +1290,14 @@ struct server_context_impl {
         res->id_slot = slot.id;
 
         res->index = slot.task->index;
-        res->content = slot.generated_text;
-        res->tokens = std::move(slot.generated_tokens);
+        // in stream mode, content and tokens are already in last partial chunk
+        if (slot.task->params.stream) {
+            res->content = "";
+            res->tokens = llama_tokens{};
+        } else {
+            res->content = std::move(slot.generated_text);
+            res->tokens = std::move(slot.generated_tokens);
+        }
         res->timings = slot.get_timings();
         res->prompt = slot.task->tokens.detokenize(ctx, true);
         res->response_fields = std::move(slot.task->params.response_fields);
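With this hunk, the final result in stream mode carries empty content and tokens, and the full completion is expected to be rebuilt on the consumer side from the partial chunks. A minimal sketch of that consumer-side view (hypothetical types and names, assumed for illustration, not code from this change):

#include <string>
#include <vector>

// Minimal sketch, not part of this diff: reconstruct the full completion by
// concatenating partial chunks, since the final stream-mode chunk no longer
// repeats the accumulated text.
struct partial_chunk {
    std::string content; // delta text carried by one partial result
};

std::string reassemble(const std::vector<partial_chunk> & partials) {
    std::string full_text;
    for (const auto & p : partials) {
        full_text += p.content; // each partial contributes only its new text
    }
    // the final (stream-mode) result now has empty content, so nothing to add
    return full_text;
}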
@@ -1338,7 +1317,6 @@ struct server_context_impl {
         res->res_type = slot.task->params.res_type;
         res->oaicompat_model = slot.task->params.oaicompat_model;
         res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id;
-        res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs);
 
         // populate res.probs_output
         if (slot.task->params.sampling.n_probs > 0) {
@@ -2596,6 +2574,9 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
     try {
         std::vector<server_task> tasks;
 
+        // tracking generation state and partial tool calls
+        std::vector<task_result_state> states;
+
         const auto & prompt = data.at("prompt");
         // TODO: this log can become very long, put it behind a flag or think about a more compact format
         // SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
@@ -2611,6 +2592,7 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
             inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
         }
         tasks.reserve(inputs.size());
+        states.reserve(inputs.size());
         for (size_t i = 0; i < inputs.size(); i++) {
             server_task task = server_task(type);
 
@@ -2628,10 +2610,12 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
             task.params.res_type = res_type;
             task.params.oaicompat_cmpl_id = completion_id;
             task.params.oaicompat_model = ctx_server.model_name;
+            states.push_back(task.params.oaicompat_chat_syntax);
 
             tasks.push_back(std::move(task));
         }
 
+        rd.set_states(std::move(states));
         rd.post_tasks(std::move(tasks));
     } catch (const std::exception & e) {
         res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
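The states vector is filled in lock-step with tasks: one task_result_state per task, pushed in the same order, presumably so the reader handed to set_states() can match each result index back to the chat syntax of its task. A minimal sketch of that pairing invariant (simplified stand-in types, assumed for illustration):

#include <cassert>
#include <vector>

// Minimal sketch, not part of this diff: the loop above keeps tasks and
// states index-aligned, which is what the response reader relies on.
struct fake_task  { int id; };       // stands in for server_task
struct fake_state { int task_id; };  // stands in for task_result_state

int main() {
    std::vector<fake_task>  tasks;
    std::vector<fake_state> states;

    for (int i = 0; i < 4; i++) {
        fake_task task{i};
        states.push_back(fake_state{task.id}); // mirrors states.push_back(task.params.oaicompat_chat_syntax)
        tasks.push_back(task);                 // mirrors tasks.push_back(std::move(task))
    }

    assert(tasks.size() == states.size());     // one state per task, same order
    for (size_t i = 0; i < tasks.size(); i++) {
        assert(tasks[i].id == states[i].task_id);
    }
    return 0;
}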
@@ -2657,7 +2641,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
             // if single request, return single object instead of array
             res->ok(arr.size() == 1 ? arr[0] : arr);
         }
-
     } else {
         // in streaming mode, the first error must be treated as non-stream response
         // this is to match the OAI API behavior
@@ -2676,76 +2659,92 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
         }
 
         // next responses are streamed
+        // to be sent immediately
+        json first_result_json = first_result->to_json();
         if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
-            res->data = format_anthropic_sse(first_result->to_json());
+            res->data = format_anthropic_sse(first_result_json);
         } else {
-            res->data = format_oai_sse(first_result->to_json()); // to be sent immediately
+            res->data = format_oai_sse(first_result_json);
         }
         res->status = 200;
         res->content_type = "text/event-stream";
         res->next = [res_this = res.get(), res_type, &should_stop](std::string & output) -> bool {
-            if (should_stop()) {
-                SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
-                return false; // should_stop condition met
-            }
-
-            if (!res_this->data.empty()) {
-                // flush the first chunk
-                output = std::move(res_this->data);
-                res_this->data.clear();
-                return true;
-            }
-
-            server_response_reader & rd = res_this->rd;
-
-            // check if there is more data
-            if (!rd.has_next()) {
-                if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
-                    // Anthropic doesn't send [DONE], message_stop was already sent
-                    output = "";
-                } else if (res_type != TASK_RESPONSE_TYPE_NONE) {
-                    output = "data: [DONE]\n\n";
-                } else {
-                    output = "";
-                }
-                SRV_DBG("%s", "all results received, terminating stream\n");
-                return false; // no more data, terminate
-            }
-
-            // receive subsequent results
-            auto result = rd.next(should_stop);
-            if (result == nullptr) {
-                SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
-                return false; // should_stop condition met
-            }
-
-            // send the results
-            json res_json = result->to_json();
-            if (result->is_error()) {
+            static auto format_error = [](task_response_type res_type, const json & res_json) {
                 if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
-                    output = format_anthropic_sse({
+                    return format_anthropic_sse({
                         {"event", "error"},
                         {"data", res_json},
                     });
                 } else {
-                    output = format_oai_sse(json {{ "error", res_json }});
+                    return format_oai_sse(json {{ "error", res_json }});
                 }
-                SRV_DBG("%s", "error received during streaming, terminating stream\n");
-                return false; // terminate on error
-            } else {
-                GGML_ASSERT(
-                    dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
-                    || dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
-                );
-                if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
-                    output = format_anthropic_sse(res_json);
+            };
+
+            try {
+                if (should_stop()) {
+                    SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
+                    return false; // should_stop condition met
+                }
+
+                if (!res_this->data.empty()) {
+                    // flush the first chunk
+                    output = std::move(res_this->data);
+                    res_this->data.clear();
+                    return true;
+                }
+
+                server_response_reader & rd = res_this->rd;
+
+                // check if there is more data
+                if (!rd.has_next()) {
+                    if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
+                        // Anthropic doesn't send [DONE], message_stop was already sent
+                        output = "";
+                    } else if (res_type != TASK_RESPONSE_TYPE_NONE) {
+                        output = "data: [DONE]\n\n";
+                    } else {
+                        output = "";
+                    }
+                    SRV_DBG("%s", "all results received, terminating stream\n");
+                    return false; // no more data, terminate
+                }
+
+                // receive subsequent results
+                auto result = rd.next(should_stop);
+                if (result == nullptr) {
+                    SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
+                    return false; // should_stop condition met
+                }
+
+                // send the results
+                if (result->is_error()) {
+                    json res_json = result->to_json();
+                    output = format_error(res_type, res_json);
+                    SRV_DBG("%s", "error received during streaming, terminating stream\n");
+                    return false; // terminate on error
                 } else {
-                    output = format_oai_sse(res_json);
+                    GGML_ASSERT(
+                        dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
+                        || dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
+                    );
+                    json res_json = result->to_json();
+                    if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
+                        output = format_anthropic_sse(res_json);
+                    } else {
+                        output = format_oai_sse(res_json);
+                    }
                 }
-            }
 
-            // has next data, continue
-            return true;
+                // has next data, continue
+                return true;
+
+            } catch (const std::exception & e) {
+                json error_json = format_error_response(e.what(), ERROR_TYPE_SERVER);
+                output = format_error(res_type, error_json);
+
+                // terminate on exception
+                return false;
+            }
         };
     }
 
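The next callback is a pull-style generator: each call writes at most one SSE chunk into output and returns false when the stream should end (after [DONE], an error event, or a caught exception). One plausible transport loop that could drive it, assuming the chunk produced on the terminating call is still forwarded (hypothetical sketch, not the server's actual HTTP glue):

#include <functional>
#include <iostream>
#include <string>

// Minimal sketch, not part of this diff: drive a next()-style callback until
// it signals termination, forwarding every non-empty chunk it produces.
void drive_stream(const std::function<bool(std::string &)> & next) {
    bool keep_going = true;
    while (keep_going) {
        std::string chunk;
        keep_going = next(chunk);  // fills chunk; false means terminate after this call
        if (!chunk.empty()) {
            std::cout << chunk;    // stands in for writing the SSE chunk to the client socket
        }
    }
}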