Skip to content

Commit ae6b870

Browse files
author
ochafik
committed
test & fix array message.content
1 parent 1c6168b commit ae6b870

File tree

3 files changed

+38
-17
lines changed

3 files changed

+38
-17
lines changed

common/chat.cpp

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messa
128128
}
129129

130130
template <>
131-
json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs) {
131+
json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text) {
132132
json messages = json::array();
133133
for (const auto & msg : msgs) {
134134
if (!msg.content.empty() && !msg.content_parts.empty()) {
@@ -140,12 +140,27 @@ json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msg
140140
if (!msg.content.empty()) {
141141
jmsg["content"] = msg.content;
142142
} else if (!msg.content_parts.empty()) {
143-
auto & parts = jmsg["content"] = json::array();
144-
for (const auto & part : msg.content_parts) {
145-
parts.push_back({
146-
{"type", part.type},
147-
{"text", part.text},
148-
});
143+
if (concat_typed_text) {
144+
std::string text;
145+
for (const auto & part : msg.content_parts) {
146+
if (part.type != "text") {
147+
LOG_WRN("Ignoring content part type: %s\n", part.type.c_str());
148+
continue;
149+
}
150+
if (!text.empty()) {
151+
text += '\n';
152+
}
153+
text += part.text;
154+
}
155+
jmsg["content"] = text;
156+
} else {
157+
auto & parts = jmsg["content"] = json::array();
158+
for (const auto & part : msg.content_parts) {
159+
parts.push_back({
160+
{"type", part.type},
161+
{"text", part.text},
162+
});
163+
}
149164
}
150165
} else {
151166
jmsg["content"] = json(); // null
@@ -1388,21 +1403,21 @@ static common_chat_params common_chat_templates_apply_jinja(
13881403
const struct common_chat_templates_inputs & inputs)
13891404
{
13901405
templates_params params;
1391-
params.messages = common_chat_msgs_to_json_oaicompat<json>(inputs.messages);
1406+
params.tools = common_chat_tools_to_json_oaicompat<json>(inputs.tools);
1407+
const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use
1408+
? *tmpls->template_tool_use
1409+
: *tmpls->template_default;
1410+
const auto & src = tmpl.source();
1411+
const auto & caps = tmpl.original_caps();
1412+
params.messages = common_chat_msgs_to_json_oaicompat<json>(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content);
13921413
params.add_generation_prompt = inputs.add_generation_prompt;
13931414
params.extract_reasoning = inputs.extract_reasoning;
1394-
params.tools = common_chat_tools_to_json_oaicompat<json>(inputs.tools);
13951415
params.tool_choice = inputs.tool_choice;
13961416
params.grammar = inputs.grammar;
13971417
if (!inputs.json_schema.empty()) {
13981418
params.json_schema = json::parse(inputs.json_schema);
13991419
}
1400-
const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use
1401-
? *tmpls->template_tool_use
1402-
: *tmpls->template_default;
1403-
const auto & src = tmpl.source();
1404-
const auto & caps = tmpl.original_caps();
1405-
1420+
14061421
if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) {
14071422
LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n");
14081423
params.parallel_tool_calls = false;
@@ -1487,6 +1502,10 @@ static common_chat_params common_chat_templates_apply_legacy(
14871502
for (const auto & msg : inputs.messages) {
14881503
auto content = msg.content;
14891504
for (const auto & part : msg.content_parts) {
1505+
if (part.type != "text") {
1506+
LOG_WRN("Ignoring non-text content part: %s\n", part.type.c_str());
1507+
continue;
1508+
}
14901509
if (!content.empty()) {
14911510
content += "\n";;
14921511
}

common/chat.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::strin
123123
// Parses a JSON array of messages in OpenAI's chat completion API format.
124124
// T can be std::string containing JSON or nlohmann::ordered_json
125125
template <class T> std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const T & messages);
126-
template <class T> T common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs);
126+
template <class T> T common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text = false);
127127

128128
// Parses a JSON array of tools in OpenAI's chat completion tool call API format.
129129
// T can be std::string containing JSON or nlohmann::ordered_json

examples/server/tests/unit/test_chat_completion.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ def create_server():
2121
(None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"),
2222
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None),
2323
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None),
24+
(None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", False, None),
25+
(None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", True, None),
2426
]
2527
)
2628
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template):
@@ -44,7 +46,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
4446
assert res.body["usage"]["completion_tokens"] == n_predicted
4547
choice = res.body["choices"][0]
4648
assert "assistant" == choice["message"]["role"]
47-
assert match_regex(re_content, choice["message"]["content"])
49+
assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}'
4850
assert choice["finish_reason"] == finish_reason
4951

5052

0 commit comments

Comments
 (0)