Skip to content

Commit c048d4f

Browse files
committed
Fix Hermes tool-call parser leaking wrapper XML
1 parent 5f7e166 commit c048d4f

File tree

4 files changed

+84
-0
lines changed

4 files changed

+84
-0
lines changed

common/chat-parser.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,3 +391,14 @@ std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parse
391391
void common_chat_msg_parser::clear_tools() {
392392
result_.tool_calls.clear();
393393
}
394+
395+
void common_chat_msg_parser::remove_content_suffix(size_t len) {
396+
if (len == 0 || result_.content.empty()) {
397+
return;
398+
}
399+
if (len >= result_.content.size()) {
400+
result_.content.clear();
401+
return;
402+
}
403+
result_.content.erase(result_.content.size() - len);
404+
}

common/chat-parser.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,4 +117,6 @@ class common_chat_msg_parser {
117117
);
118118

119119
void clear_tools();
120+
121+
void remove_content_suffix(size_t len);
120122
};

common/chat.cpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2116,6 +2116,17 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
21162116
"|<function name=\"([^\"]+)\">" // match 5 (function name again)
21172117
);
21182118

2119+
static const std::vector<std::string> wrapper_open_tags = {
2120+
"<tool_call>",
2121+
"<function_call>",
2122+
"<tool>",
2123+
"<tools>",
2124+
"<response>",
2125+
"<json>",
2126+
"<xml>",
2127+
"<JSON>",
2128+
};
2129+
21192130
while (auto res = builder.try_find_regex(open_regex)) {
21202131
const auto & block_start = res->groups[1];
21212132
std::string block_end = block_start.empty() ? "" : "```";
@@ -2142,24 +2153,74 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
21422153
throw common_chat_msg_partial_exception("failed to parse tool call");
21432154
}
21442155
} else {
2156+
auto prelude = res->prelude;
2157+
bool prelude_wrappers_only = true;
2158+
auto trimmed_prelude = string_strip(prelude);
2159+
while (!trimmed_prelude.empty()) {
2160+
bool matched_wrapper = false;
2161+
for (const auto & tag : wrapper_open_tags) {
2162+
if (string_starts_with(trimmed_prelude, tag)) {
2163+
trimmed_prelude = string_strip(trimmed_prelude.substr(tag.size()));
2164+
matched_wrapper = true;
2165+
break;
2166+
}
2167+
}
2168+
if (!matched_wrapper) {
2169+
prelude_wrappers_only = false;
2170+
break;
2171+
}
2172+
}
2173+
if (!prelude.empty() && prelude_wrappers_only) {
2174+
builder.remove_content_suffix(prelude.size());
2175+
}
2176+
21452177
auto function_name = builder.str(res->groups[4]);
21462178
if (function_name.empty()) {
21472179
function_name = builder.str(res->groups[5]);
21482180
}
21492181
GGML_ASSERT(!function_name.empty());
21502182

21512183
close_tag = "</function>";
2184+
const bool had_block_start = res->prelude.find("```") != std::string::npos;
21522185

21532186
if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) {
21542187
if (!builder.add_tool_call(function_name, "", arguments->value) || arguments->is_partial) {
21552188
throw common_chat_msg_partial_exception("incomplete tool call");
21562189
}
21572190
builder.consume_spaces();
21582191
builder.consume_literal(close_tag);
2192+
2193+
static const std::vector<std::string> wrapper_close_tags = {
2194+
"</tool_call>",
2195+
"</tool>",
2196+
"</tools>",
2197+
"</response>",
2198+
"</json>",
2199+
"</xml>",
2200+
"</JSON>",
2201+
};
2202+
2203+
while (true) {
2204+
builder.consume_spaces();
2205+
bool matched_wrapper = false;
2206+
for (const auto & wrapper_close : wrapper_close_tags) {
2207+
if (builder.try_consume_literal(wrapper_close)) {
2208+
matched_wrapper = true;
2209+
break;
2210+
}
2211+
}
2212+
if (!matched_wrapper) {
2213+
break;
2214+
}
2215+
}
2216+
21592217
builder.consume_spaces();
21602218
if (!block_end.empty()) {
21612219
builder.consume_literal(block_end);
21622220
builder.consume_spaces();
2221+
} else if (had_block_start) {
2222+
builder.try_consume_literal("```");
2223+
builder.consume_spaces();
21632224
}
21642225
}
21652226
}

tests/test-chat.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -840,6 +840,16 @@ static void test_template_output_parsers() {
840840
"</function>",
841841
/* is_partial= */ false,
842842
{COMMON_CHAT_FORMAT_HERMES_2_PRO}));
843+
assert_msg_equals(
844+
message_assist_call,
845+
common_chat_parse(
846+
"<tool_call>\n"
847+
"<function=special_function>\n"
848+
"{\"arg1\": 1}\n"
849+
"</function>\n"
850+
"</tool_call>\n",
851+
/* is_partial= */ false,
852+
{COMMON_CHAT_FORMAT_HERMES_2_PRO}));
843853
assert_msg_equals(
844854
message_assist_call,
845855
common_chat_parse(

0 commit comments

Comments
 (0)