diff --git a/common/chat-parser.cpp b/common/chat-parser.cpp index 96ba8f533ef1b..7f28557ddedc5 100644 --- a/common/chat-parser.cpp +++ b/common/chat-parser.cpp @@ -391,3 +391,14 @@ std::optional common_chat_msg_parse void common_chat_msg_parser::clear_tools() { result_.tool_calls.clear(); } + +void common_chat_msg_parser::remove_content_suffix(size_t len) { + if (len == 0 || result_.content.empty()) { + return; + } + if (len >= result_.content.size()) { + result_.content.clear(); + return; + } + result_.content.erase(result_.content.size() - len); +} diff --git a/common/chat-parser.h b/common/chat-parser.h index 0e64c341a50aa..d5e4fd19b43ad 100644 --- a/common/chat-parser.h +++ b/common/chat-parser.h @@ -117,4 +117,6 @@ class common_chat_msg_parser { ); void clear_tools(); + + void remove_content_suffix(size_t len); }; diff --git a/common/chat.cpp b/common/chat.cpp index e2bacdcf52753..3441a5d8fc417 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -9,6 +9,7 @@ #include #include +#include #include #include #include @@ -43,6 +44,36 @@ static std::string string_diff(const std::string & last, const std::string & cur return current.substr(last.size()); } +static std::string string_remove_tool_wrappers_suffix(const std::string & delta) { + size_t cut_pos = std::string::npos; + auto consider = [&](const std::string & token) { + auto pos = delta.find(token); + if (pos != std::string::npos) { + cut_pos = cut_pos == std::string::npos ? pos : std::min(cut_pos, pos); + } + }; + + consider(""); + consider(""); + consider(""); + consider(""); + consider(""); + consider(""); + consider(""); + consider("(cleaned.back()))) { + cleaned.pop_back(); + } + return cleaned; +} + static bool has_content_or_tool_calls(const common_chat_msg & msg) { return !msg.content.empty() || !msg.tool_calls.empty(); } @@ -89,8 +120,12 @@ std::vector common_chat_msg_diff::compute_diffs(const comm diff.reasoning_content_delta = string_diff(previous_msg.reasoning_content, new_msg.reasoning_content); } if (previous_msg.content != new_msg.content) { - auto & diff = diffs.emplace_back(); - diff.content_delta = string_diff(previous_msg.content, new_msg.content); + auto content_delta = string_diff(previous_msg.content, new_msg.content); + auto cleaned_delta = string_remove_tool_wrappers_suffix(content_delta); + if (!string_strip(cleaned_delta).empty()) { + auto & diff = diffs.emplace_back(); + diff.content_delta = cleaned_delta; + } } if (new_msg.tool_calls.size() < previous_msg.tool_calls.size()) { @@ -2116,6 +2151,17 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) { "|" // match 5 (function name again) ); + static const std::vector wrapper_open_tags = { + "", + "", + "", + "", + "", + "", + "", + "", + }; + while (auto res = builder.try_find_regex(open_regex)) { const auto & block_start = res->groups[1]; std::string block_end = block_start.empty() ? "" : "```"; @@ -2142,6 +2188,37 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) { throw common_chat_msg_partial_exception("failed to parse tool call"); } } else { + auto prelude = res->prelude; + size_t suffix_to_remove = 0; + size_t cursor = prelude.size(); + while (cursor > 0) { + size_t ws = cursor; + while (ws > 0 && std::isspace(static_cast(prelude[ws - 1]))) { + --ws; + } + size_t trimmed_end = ws; + bool matched_wrapper = false; + for (const auto & tag : wrapper_open_tags) { + const size_t tag_len = tag.size(); + if (trimmed_end >= tag_len && prelude.compare(trimmed_end - tag_len, tag_len, tag) == 0) { + matched_wrapper = true; + suffix_to_remove += cursor - (trimmed_end - tag_len); + cursor = trimmed_end - tag_len; + break; + } + } + if (!matched_wrapper) { + break; + } + } + if (suffix_to_remove > 0) { + while (cursor > 0 && std::isspace(static_cast(prelude[cursor - 1]))) { + --cursor; + ++suffix_to_remove; + } + builder.remove_content_suffix(suffix_to_remove); + } + auto function_name = builder.str(res->groups[4]); if (function_name.empty()) { function_name = builder.str(res->groups[5]); @@ -2149,6 +2226,13 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) { GGML_ASSERT(!function_name.empty()); close_tag = ""; + bool had_block_start = false; + { + const auto backtick_pos = res->prelude.rfind("```"); + if (backtick_pos != std::string::npos) { + had_block_start = true; + } + } if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) { if (!builder.add_tool_call(function_name, "", arguments->value) || arguments->is_partial) { @@ -2156,10 +2240,38 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) { } builder.consume_spaces(); builder.consume_literal(close_tag); + + static const std::vector wrapper_close_tags = { + "", + "", + "", + "", + "", + "", + "", + }; + + while (true) { + builder.consume_spaces(); + bool matched_wrapper = false; + for (const auto & wrapper_close : wrapper_close_tags) { + if (builder.try_consume_literal(wrapper_close)) { + matched_wrapper = true; + break; + } + } + if (!matched_wrapper) { + break; + } + } + builder.consume_spaces(); if (!block_end.empty()) { builder.consume_literal(block_end); builder.consume_spaces(); + } else if (had_block_start) { + builder.try_consume_literal("```"); + builder.consume_spaces(); } } } diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index ce0f4b0a2a9f3..3ad422cc469f8 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -826,6 +826,16 @@ static void test_template_output_parsers() { "", /* is_partial= */ false, {COMMON_CHAT_FORMAT_HERMES_2_PRO})); + assert_msg_equals(message_assist_call_content, + common_chat_parse( + "Hello, world!\nWhat's up?\n" + "\n" + "\n" + "{\"arg1\": 1}\n" + "\n" + "\n", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_HERMES_2_PRO})); assert_msg_equals( message_assist_call, common_chat_parse( @@ -840,6 +850,16 @@ static void test_template_output_parsers() { "", /* is_partial= */ false, {COMMON_CHAT_FORMAT_HERMES_2_PRO})); + assert_msg_equals( + message_assist_call, + common_chat_parse( + "\n" + "\n" + "{\"arg1\": 1}\n" + "\n" + "\n", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_HERMES_2_PRO})); assert_msg_equals( message_assist_call, common_chat_parse(