Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions common/chat-parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -391,3 +391,14 @@ std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parse
// Empties the tool-call list held in the parser's result_.
void common_chat_msg_parser::clear_tools() {
    auto & calls = result_.tool_calls;
    calls.clear();
}

// Drops the last `len` characters from the content accumulated in result_.
// A request of zero characters (or empty content) is a no-op; a request that
// is at least as long as the content leaves it empty.
void common_chat_msg_parser::remove_content_suffix(size_t len) {
    auto & content = result_.content;
    const size_t current = content.size();
    // Clamp so we never try to remove more than is actually there.
    const size_t dropped = len < current ? len : current;
    if (dropped != 0) {
        content.erase(current - dropped);
    }
}
2 changes: 2 additions & 0 deletions common/chat-parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,4 +117,6 @@ class common_chat_msg_parser {
);

void clear_tools();

void remove_content_suffix(size_t len);
};
116 changes: 114 additions & 2 deletions common/chat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <minja/chat-template.hpp>
#include <minja/minja.hpp>

#include <cctype>
#include <cstdio>
#include <exception>
#include <iostream>
Expand Down Expand Up @@ -43,6 +44,36 @@ static std::string string_diff(const std::string & last, const std::string & cur
return current.substr(last.size());
}

// Truncates a streamed content delta at the first tool-call wrapper tag.
//
// The delta is scanned for every known wrapper / function opening token. If
// none occurs, the delta is returned unchanged (whitespace included). Otherwise
// everything from the earliest match onwards is dropped, and any whitespace
// immediately preceding the cut is trimmed as well.
//
// The token list mirrors the wrapper tags recognised by the Hermes 2 Pro
// parser (including "<function_call>") plus the two "<function" openers, so
// that a wrapper the parser strips from the content is not leaked in a delta.
static std::string string_remove_tool_wrappers_suffix(const std::string & delta) {
    static const char * const wrapper_tokens[] = {
        "<tool_call>",
        "<function_call>",
        "<tool>",
        "<tools>",
        "<response>",
        "<json>",
        "<xml>",
        "<JSON>",
        "<function=",
        "<function ",
    };

    size_t cut_pos = std::string::npos;
    for (const char * token : wrapper_tokens) {
        // npos is the largest size_t, so std::min keeps the earliest real match.
        cut_pos = std::min(cut_pos, delta.find(token));
    }

    if (cut_pos == std::string::npos) {
        return delta;
    }

    // Walk back over whitespace that sat between the text and the wrapper tag.
    size_t end = cut_pos;
    while (end > 0 && std::isspace(static_cast<unsigned char>(delta[end - 1]))) {
        --end;
    }
    return delta.substr(0, end);
}

// Reports whether the message carries anything to emit: non-empty content,
// at least one tool call, or both.
static bool has_content_or_tool_calls(const common_chat_msg & msg) {
    const bool nothing_to_emit = msg.content.empty() && msg.tool_calls.empty();
    return !nothing_to_emit;
}
Expand Down Expand Up @@ -89,8 +120,12 @@ std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const comm
diff.reasoning_content_delta = string_diff(previous_msg.reasoning_content, new_msg.reasoning_content);
}
if (previous_msg.content != new_msg.content) {
auto & diff = diffs.emplace_back();
diff.content_delta = string_diff(previous_msg.content, new_msg.content);
auto content_delta = string_diff(previous_msg.content, new_msg.content);
auto cleaned_delta = string_remove_tool_wrappers_suffix(content_delta);
if (!string_strip(cleaned_delta).empty()) {
auto & diff = diffs.emplace_back();
diff.content_delta = cleaned_delta;
}
}

if (new_msg.tool_calls.size() < previous_msg.tool_calls.size()) {
Expand Down Expand Up @@ -2116,6 +2151,17 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
"|<function name=\"([^\"]+)\">" // match 5 (function name again)
);

static const std::vector<std::string> wrapper_open_tags = {
"<tool_call>",
"<function_call>",
"<tool>",
"<tools>",
"<response>",
"<json>",
"<xml>",
"<JSON>",
};

while (auto res = builder.try_find_regex(open_regex)) {
const auto & block_start = res->groups[1];
std::string block_end = block_start.empty() ? "" : "```";
Expand All @@ -2142,24 +2188,90 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
throw common_chat_msg_partial_exception("failed to parse tool call");
}
} else {
auto prelude = res->prelude;
size_t suffix_to_remove = 0;
size_t cursor = prelude.size();
while (cursor > 0) {
size_t ws = cursor;
while (ws > 0 && std::isspace(static_cast<unsigned char>(prelude[ws - 1]))) {
--ws;
}
size_t trimmed_end = ws;
bool matched_wrapper = false;
for (const auto & tag : wrapper_open_tags) {
const size_t tag_len = tag.size();
if (trimmed_end >= tag_len && prelude.compare(trimmed_end - tag_len, tag_len, tag) == 0) {
matched_wrapper = true;
suffix_to_remove += cursor - (trimmed_end - tag_len);
cursor = trimmed_end - tag_len;
break;
}
}
if (!matched_wrapper) {
break;
}
}
if (suffix_to_remove > 0) {
while (cursor > 0 && std::isspace(static_cast<unsigned char>(prelude[cursor - 1]))) {
--cursor;
++suffix_to_remove;
}
builder.remove_content_suffix(suffix_to_remove);
}

auto function_name = builder.str(res->groups[4]);
if (function_name.empty()) {
function_name = builder.str(res->groups[5]);
}
GGML_ASSERT(!function_name.empty());

close_tag = "</function>";
bool had_block_start = false;
{
const auto backtick_pos = res->prelude.rfind("```");
if (backtick_pos != std::string::npos) {
had_block_start = true;
}
}

if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) {
if (!builder.add_tool_call(function_name, "", arguments->value) || arguments->is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
builder.consume_spaces();
builder.consume_literal(close_tag);

static const std::vector<std::string> wrapper_close_tags = {
"</tool_call>",
"</tool>",
"</tools>",
"</response>",
"</json>",
"</xml>",
"</JSON>",
};

while (true) {
builder.consume_spaces();
bool matched_wrapper = false;
for (const auto & wrapper_close : wrapper_close_tags) {
if (builder.try_consume_literal(wrapper_close)) {
matched_wrapper = true;
break;
}
}
if (!matched_wrapper) {
break;
}
}

builder.consume_spaces();
if (!block_end.empty()) {
builder.consume_literal(block_end);
builder.consume_spaces();
} else if (had_block_start) {
builder.try_consume_literal("```");
builder.consume_spaces();
}
}
}
Expand Down
20 changes: 20 additions & 0 deletions tests/test-chat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -826,6 +826,16 @@ static void test_template_output_parsers() {
"</tool_call>",
/* is_partial= */ false,
{COMMON_CHAT_FORMAT_HERMES_2_PRO}));
assert_msg_equals(message_assist_call_content,
common_chat_parse(
"Hello, world!\nWhat's up?\n"
"<tool_call>\n"
"<function=special_function>\n"
"{\"arg1\": 1}\n"
"</function>\n"
"</tool_call>\n",
/* is_partial= */ false,
{COMMON_CHAT_FORMAT_HERMES_2_PRO}));
assert_msg_equals(
message_assist_call,
common_chat_parse(
Expand All @@ -840,6 +850,16 @@ static void test_template_output_parsers() {
"</function>",
/* is_partial= */ false,
{COMMON_CHAT_FORMAT_HERMES_2_PRO}));
assert_msg_equals(
message_assist_call,
common_chat_parse(
"<tool_call>\n"
"<function=special_function>\n"
"{\"arg1\": 1}\n"
"</function>\n"
"</tool_call>\n",
/* is_partial= */ false,
{COMMON_CHAT_FORMAT_HERMES_2_PRO}));
assert_msg_equals(
message_assist_call,
common_chat_parse(
Expand Down