Skip to content

Commit f738989

Browse files
authored
chat : fix multiple tool_calls on hermes-2-pro (#14962)
1 parent 4cb208c commit f738989

File tree

2 files changed

+46
-5
lines changed

2 files changed

+46
-5
lines changed

common/chat.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1646,7 +1646,7 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
16461646
"|<function name=\"([^\"]+)\">" // match 5 (function name again)
16471647
);
16481648

1649-
if (auto res = builder.try_find_regex(open_regex)) {
1649+
while (auto res = builder.try_find_regex(open_regex)) {
16501650
const auto & block_start = res->groups[1];
16511651
std::string block_end = block_start.empty() ? "" : "```";
16521652

@@ -1668,7 +1668,6 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
16681668
builder.consume_literal(block_end);
16691669
builder.consume_spaces();
16701670
}
1671-
builder.add_content(builder.consume_rest());
16721671
} else {
16731672
throw common_chat_msg_partial_exception("failed to parse tool call");
16741673
}
@@ -1693,11 +1692,10 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
16931692
builder.consume_spaces();
16941693
}
16951694
}
1696-
builder.add_content(builder.consume_rest());
16971695
}
1698-
} else {
1699-
builder.add_content(builder.consume_rest());
17001696
}
1697+
1698+
builder.add_content(builder.consume_rest());
17011699
}
17021700

17031701
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {

tests/test-chat.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -953,6 +953,33 @@ static void test_template_output_parsers() {
953953
/* is_partial= */ false,
954954
{COMMON_CHAT_FORMAT_HERMES_2_PRO}));
955955

956+
// Test multiple tool calls
957+
common_chat_msg message_assist_multiple_calls;
958+
message_assist_multiple_calls.role = "assistant";
959+
message_assist_multiple_calls.content = "";
960+
message_assist_multiple_calls.tool_calls.push_back({"special_function", "{\"arg1\": 1}", ""});
961+
message_assist_multiple_calls.tool_calls.push_back({"python", "{\"code\":\"print('hello')\"}", ""});
962+
963+
assert_msg_equals(
964+
message_assist_multiple_calls,
965+
common_chat_parse(
966+
"<tool_call>\n"
967+
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
968+
"</tool_call>\n"
969+
"<tool_call>\n"
970+
"{\"name\": \"python\", \"arguments\": {\"code\":\"print('hello')\"}}\n"
971+
"</tool_call>",
972+
/* is_partial= */ false,
973+
{COMMON_CHAT_FORMAT_HERMES_2_PRO}));
974+
975+
assert_msg_equals(
976+
message_assist_multiple_calls,
977+
common_chat_parse(
978+
"<function=special_function>{\"arg1\": 1}</function>\n"
979+
"<function=python>{\"code\":\"print('hello')\"}</function>",
980+
/* is_partial= */ false,
981+
{COMMON_CHAT_FORMAT_HERMES_2_PRO}));
982+
956983
assert_msg_equals(
957984
simple_assist_msg(
958985
"This is not a tool call:",
@@ -1039,6 +1066,22 @@ static void test_template_output_parsers() {
10391066
"<tool_call>\n"
10401067
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
10411068
"</tool_call>");
1069+
1070+
// Test multiple tool calls with template
1071+
common_chat_msg message_assist_multiple_calls_template;
1072+
message_assist_multiple_calls_template.role = "assistant";
1073+
message_assist_multiple_calls_template.content = "";
1074+
message_assist_multiple_calls_template.tool_calls.push_back({"special_function", "{\"arg1\": 1}", ""});
1075+
message_assist_multiple_calls_template.tool_calls.push_back({"python", "{\"code\":\"print('test')\"}", ""});
1076+
1077+
test_templates(tmpls.get(), end_tokens, message_assist_multiple_calls_template, tools,
1078+
"<tool_call>\n"
1079+
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
1080+
"</tool_call>\n"
1081+
"<tool_call>\n"
1082+
"{\"name\": \"python\", \"arguments\": {\"code\":\"print('test')\"}}\n"
1083+
"</tool_call>");
1084+
10421085
test_templates(tmpls.get(), end_tokens, message_assist_call_python_lines, tools,
10431086
"<tool_call>\n"
10441087
"{\"name\": \"python\", \"arguments\": {\"code\":\"# This is a program:\\nprint('hey')\"}}\n"

0 commit comments

Comments
 (0)