
Commit 5031366

Author: ochafik
send final diff from server, to close off raw python arguments
1 parent a818114 commit 5031366

File tree: 3 files changed (+92, -82 lines)


common/chat.cpp

Lines changed: 29 additions & 0 deletions
@@ -364,6 +364,35 @@ json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & t
     return result;
 }

+template <> json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) {
+    json delta = json::object();
+    // if (!diff.reasoning_content_delta.empty()) {
+    //     delta["reasoning_content"] = msg.reasoning_content;
+    // }
+    if (!diff.content_delta.empty()) {
+        delta["content"] = diff.content_delta;
+    }
+    if (diff.tool_call_index != std::string::npos) {
+        json function = json::object();
+        if (!diff.tool_call_delta.name.empty()) {
+            function["name"] = diff.tool_call_delta.name;
+        }
+        if (!diff.tool_call_delta.id.empty()) {
+            function["id"] = diff.tool_call_delta.id;
+        }
+        if (!diff.tool_call_delta.arguments.empty()) {
+            function["arguments"] = diff.tool_call_delta.arguments;
+        }
+        delta["tool_calls"] = json::array({
+            json {
+                {"index", diff.tool_call_index},
+                {"function", function}
+            }
+        });
+    }
+    return delta;
+}
+
 bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
     if (use_jinja) {
         try {
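
For illustration, a minimal sketch of how a caller could exercise the new helper. The standalone main(), the include paths, and the tool-call values are assumptions made for the sketch; the field names and the <json> instantiation follow the diff above.

#include <cstdio>

#include <nlohmann/json.hpp>

#include "chat.h"  // assumed include path for common_chat_msg_diff and the helper

using json = nlohmann::ordered_json;

int main() {
    // Fabricate a diff carrying one streamed fragment of a tool call (hypothetical values).
    common_chat_msg_diff diff;
    diff.tool_call_index           = 0;
    diff.tool_call_delta.name      = "python";
    diff.tool_call_delta.id        = "call_0";
    diff.tool_call_delta.arguments = "{\"code\": \"prin";  // partial raw arguments

    // Serializes into the OpenAI-compatible streaming delta built above, i.e.
    // {"tool_calls":[{"index":0,"function":{"name":"python","id":"call_0","arguments":"{\"code\": \"prin"}}]}
    const json delta = common_chat_msg_diff_to_json_oaicompat<json>(diff);
    std::printf("%s\n", delta.dump().c_str());
    return 0;
}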

common/chat.h

Lines changed: 2 additions & 0 deletions
@@ -193,3 +193,5 @@ template <class T> T common_chat_msgs_to_json_oaicompat(const std::vector<common
 // T can be std::string containing JSON or nlohmann::ordered_json
 template <class T> std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const T & tools);
 template <class T> T common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);
+
+template <class T> T common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff);

examples/server/server.cpp

Lines changed: 61 additions & 82 deletions
@@ -642,8 +642,8 @@ struct server_task_result_cmpl_final : server_task_result {
     oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
     std::string oaicompat_model;
     std::string oaicompat_cmpl_id;
-    common_chat_syntax oaicompat_chat_syntax;
     common_chat_msg oaicompat_msg;
+    std::vector<common_chat_msg_diff> oaicompat_msg_diffs;

     virtual int get_index() override {
         return index;
@@ -794,14 +794,32 @@ struct server_task_result_cmpl_final : server_task_result {
             finish_reason = oaicompat_msg.tool_calls.empty() ? "stop" : "tool_calls";
         }

-        json choice = json {
-            {"finish_reason", finish_reason},
-            {"index", 0},
-            {"delta", json::object()}
-        };
+        json deltas = json::array();
+        for (const auto & diff : oaicompat_msg_diffs) {
+            deltas.push_back({
+                {"choices", json::array({
+                    json {
+                        {"finish_reason", nullptr},
+                        {"index", 0},
+                        {"delta", common_chat_msg_diff_to_json_oaicompat<json>(diff)},
+                    },
+                })},
+                {"created", t},
+                {"id", oaicompat_cmpl_id},
+                {"model", oaicompat_model},
+                {"system_fingerprint", build_info},
+                {"object", "chat.completion.chunk"},
+            });
+        }

-        json ret = json {
-            {"choices", json::array({choice})},
+        deltas.push_back({
+            {"choices", json::array({
+                json {
+                    {"finish_reason", finish_reason},
+                    {"index", 0},
+                    {"delta", json::object()},
+                },
+            })},
             {"created", t},
             {"id", oaicompat_cmpl_id},
             {"model", oaicompat_model},
@@ -812,13 +830,13 @@ struct server_task_result_cmpl_final : server_task_result {
                 {"prompt_tokens", n_prompt_tokens},
                 {"total_tokens", n_decoded + n_prompt_tokens},
             }},
-        };
+        });

         if (timings.prompt_n >= 0) {
-            ret.push_back({"timings", timings.to_json()});
+            deltas.back().push_back({"timings", timings.to_json()});
         }

-        return ret;
+        return deltas;
     }
 };


@@ -840,8 +858,7 @@ struct server_task_result_cmpl_partial : server_task_result {
     oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
     std::string oaicompat_model;
     std::string oaicompat_cmpl_id;
-    common_chat_msg oaicompat_previous_msg;
-    common_chat_msg oaicompat_new_msg;
+    std::vector<common_chat_msg_diff> oaicompat_msg_diffs;

     virtual int get_index() override {
         return index;
@@ -926,9 +943,9 @@ struct server_task_result_cmpl_partial : server_task_result {
         std::time_t t = std::time(0);
         json choices;

-        std::vector<json> rets;
-        auto add_ret = [&](const json & delta) {
-            rets.push_back({
+        std::vector<json> deltas;
+        auto add_delta = [&](const json & delta) {
+            deltas.push_back({
                 {"choices", json::array({
                     json {
                         {"finish_reason", nullptr},
@@ -945,66 +962,31 @@ struct server_task_result_cmpl_partial : server_task_result {
         };
         // We have to send an initial update to conform to openai behavior
         if (first) {
-            add_ret({
+            add_delta({
                 {"role", "assistant"},
                 {"content", nullptr},
             });
         }

-        common_chat_msg previous_msg;
-        if (oaicompat_previous_msg.empty()) {
-            previous_msg.role = "assistant";
-        } else {
-            previous_msg = oaicompat_previous_msg;
-        }
-        if (!oaicompat_new_msg.empty()) {
-            auto new_msg = oaicompat_new_msg;
-            auto diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg);
-            for (const auto & diff : diffs) {
-                json delta = json::object();
-                // if (!diff.reasoning_content_delta.empty()) {
-                //     delta["reasoning_content"] = msg.reasoning_content;
-                // }
-                if (!diff.content_delta.empty()) {
-                    delta["content"] = diff.content_delta;
-                }
-                if (diff.tool_call_index != std::string::npos) {
-                    json function = json::object();
-                    if (!diff.tool_call_delta.name.empty()) {
-                        function["name"] = diff.tool_call_delta.name;
-                    }
-                    if (!diff.tool_call_delta.id.empty()) {
-                        function["id"] = diff.tool_call_delta.id;
-                    }
-                    if (!diff.tool_call_delta.arguments.empty()) {
-                        function["arguments"] = diff.tool_call_delta.arguments;
-                    }
-                    delta["tool_calls"] = json::array({
-                        json {
-                            {"index", diff.tool_call_index},
-                            {"function", function}
-                        }
-                    });
-                }
-                add_ret(delta);
-            }
+        for (const auto & diff : oaicompat_msg_diffs) {
+            add_delta(common_chat_msg_diff_to_json_oaicompat<json>(diff));
         }

-        if (!rets.empty()) {
-            GGML_ASSERT(rets[rets.size() - 1].at("choices").size() >= 1);
+        if (!deltas.empty()) {
+            GGML_ASSERT(deltas[deltas.size() - 1].at("choices").size() >= 1);

             if (prob_output.probs.size() > 0) {
-                rets[rets.size() - 1].at("choices").at(0)["logprobs"] = json {
+                deltas[deltas.size() - 1].at("choices").at(0)["logprobs"] = json {
                     {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
                 };
             }

             if (timings.prompt_n >= 0) {
-                rets[rets.size() - 1].push_back({"timings", timings.to_json()});
+                deltas[deltas.size() - 1].push_back({"timings", timings.to_json()});
             }
         }

-        return rets;
+        return deltas;
     }
 };

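On the receiving end these chunks are meant to be concatenated. Below is a hedged sketch (not part of the commit) of how a client could reassemble streamed tool-call arguments from chunks shaped like the ones built above; the function name and the pre-collected chunk vector are illustrative.

#include <map>
#include <string>
#include <vector>

#include <nlohmann/json.hpp>

using json = nlohmann::ordered_json;

// Illustrative only: accumulate "arguments" fragments per tool-call index from a
// sequence of parsed "chat.completion.chunk" payloads.
static std::map<int, std::string> collect_tool_call_arguments(const std::vector<json> & chunks) {
    std::map<int, std::string> args_by_index;
    for (const auto & chunk : chunks) {
        for (const auto & choice : chunk.at("choices")) {
            const auto & delta = choice.at("delta");
            if (!delta.contains("tool_calls")) {
                continue;
            }
            for (const auto & tc : delta.at("tool_calls")) {
                const int index = tc.at("index").get<int>();
                const auto & fn = tc.at("function");
                if (fn.contains("arguments")) {
                    // Fragments arrive in order; appending them rebuilds the raw argument
                    // string (e.g. raw python code).
                    args_by_index[index] += fn.at("arguments").get<std::string>();
                }
            }
        }
    }
    return args_by_index;
}

Because the final response now emits any remaining diffs before the finish_reason chunk, the accumulated arguments string is complete by the time the stream ends, which appears to be what the commit message means by closing off raw python arguments.
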
@@ -1268,7 +1250,7 @@ struct server_slot {

     std::string generated_text;
     llama_tokens generated_tokens;
-    common_chat_msg generated_msg;
+    common_chat_msg chat_msg;

     llama_tokens cache_tokens;

@@ -1319,7 +1301,7 @@ struct server_slot {

         generated_tokens.clear();
         generated_token_probs.clear();
-        generated_msg = {};
+        chat_msg = {};
         json_schema = json();
         generated_tool_call_ids.clear();
     }
@@ -1391,6 +1373,21 @@ struct server_slot {
         return timings;
     }

+    const common_chat_msg & update_chat_msg(std::vector<common_chat_msg_diff> & diffs) {
+        auto previous_msg = chat_msg;
+        SRV_DBG("Parsing chat message: %s\n", generated_text.c_str());
+        auto new_msg = common_chat_parse(
+            generated_text,
+            /* is_partial= */ stop != STOP_TYPE_EOS,
+            params.oaicompat_chat_syntax);
+        if (!new_msg.empty()) {
+            new_msg.ensure_tool_call_ids_set(generated_tool_call_ids, gen_tool_call_id);
+            chat_msg = new_msg;
+            diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg);
+        }
+        return chat_msg;
+    }
+
     size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
         size_t stop_pos = std::string::npos;

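Conceptually, update_chat_msg() re-parses the accumulated generated_text on every call and reports only what changed since the previous parse via compute_diffs(). A rough sketch of that behavior in isolation follows; the message contents are made up, and the single-diff expectation is an assumption about compute_diffs().

#include <vector>

#include "chat.h"  // assumed include path for common_chat_msg and common_chat_msg_diff

// Illustrative only: two snapshots of a growing assistant message and the incremental
// deltas between them, which is what update_chat_msg() hands back to the partial and
// final result objects on each call.
static std::vector<common_chat_msg_diff> example_incremental_diffs() {
    common_chat_msg previous;
    previous.role    = "assistant";
    previous.content = "Let me run some code.";

    common_chat_msg current = previous;
    current.content += " Running it now.";  // newly decoded tokens extended the content

    // Expected (assumption) to yield one diff whose content_delta is " Running it now."
    return common_chat_msg_diff::compute_diffs(previous, current);
}
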
@@ -2358,18 +2355,7 @@ struct server_context {
         res->oaicompat_model = slot.params.oaicompat_model;
         res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;

-        auto previous_msg = slot.generated_msg;
-        SRV_DBG("Parsing chat message: %s\n", slot.generated_text.c_str());
-        auto new_msg = common_chat_parse(
-            slot.generated_text,
-            /* is_partial= */ true,
-            slot.params.oaicompat_chat_syntax);
-        if (!new_msg.empty()) {
-            new_msg.ensure_tool_call_ids_set(slot.generated_tool_call_ids, gen_tool_call_id);
-            slot.generated_msg = new_msg;
-        }
-        res->oaicompat_previous_msg = previous_msg;
-        res->oaicompat_new_msg = new_msg.empty() ? previous_msg : new_msg;
+        slot.update_chat_msg(res->oaicompat_msg_diffs);

         // populate res.probs_output
         if (slot.params.sampling.n_probs > 0) {
@@ -2390,7 +2376,7 @@ struct server_context {
         res->id_slot = slot.id;

         res->index = slot.index;
-        res->content = std::move(slot.generated_text);
+        res->content = slot.generated_text;
         res->tokens = std::move(slot.generated_tokens);
         res->timings = slot.get_timings();
         res->prompt = common_detokenize(ctx, slot.prompt_tokens, true);
@@ -2410,14 +2396,7 @@ struct server_context {
         res->oaicompat = slot.params.oaicompat;
         res->oaicompat_model = slot.params.oaicompat_model;
         res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
-
-        SRV_DBG("Parsing chat message: %s\n", res->content.c_str());
-        res->oaicompat_msg = slot.generated_msg = common_chat_parse(
-            res->content,
-            /* is_partial= */ slot.stop == STOP_TYPE_LIMIT,
-            slot.params.oaicompat_chat_syntax);
-        res->oaicompat_msg.ensure_tool_call_ids_set(slot.generated_tool_call_ids, gen_tool_call_id);
-        res->oaicompat_chat_syntax = slot.params.oaicompat_chat_syntax;
+        res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs);

         // populate res.probs_output
         if (slot.params.sampling.n_probs > 0) {
