Skip to content

Commit a818114

Browse files
author
ochafik
committed
make functionary v3.2 parsing more strict (differentiate first match from others)
1 parent e9540ad commit a818114

File tree

4 files changed: 60 additions and 23 deletions

common/chat-parser.cpp

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -193,14 +193,14 @@ std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser:
193193
return find_regex_result{prelude, m.groups};
194194
}
195195

196-
common_chat_msg_parser::consume_regex_result common_chat_msg_parser::consume_regex(const common_regex & regex) {
196+
common_chat_msg_parser::find_regex_result common_chat_msg_parser::consume_regex(const common_regex & regex) {
197197
if (auto result = try_consume_regex(regex)) {
198198
return *result;
199199
}
200200
throw common_chat_msg_partial_exception(regex.str());
201201
}
202202

203-
std::optional<common_chat_msg_parser::consume_regex_result> common_chat_msg_parser::try_consume_regex(const common_regex & regex) {
203+
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_consume_regex(const common_regex & regex) {
204204
auto m = regex.search(input_, pos_);
205205
if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) {
206206
return std::nullopt;
@@ -217,7 +217,10 @@ std::optional<common_chat_msg_parser::consume_regex_result> common_chat_msg_pars
217217
}
218218
pos_ = m.groups[0].end;
219219

220-
return consume_regex_result{m.groups};
220+
return find_regex_result {
221+
/* .prelude = */ "",
222+
m.groups,
223+
};
221224
}
222225

223226
std::optional<common_json> common_chat_msg_parser::try_consume_json() {

common/chat-parser.h

Lines changed: 2 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -83,12 +83,9 @@ class common_chat_msg_parser {
8383

8484
std::optional<find_regex_result> try_find_literal(const std::string & literal);
8585

86-
struct consume_regex_result {
87-
std::vector<common_string_range> groups;
88-
};
89-
consume_regex_result consume_regex(const common_regex & regex);
86+
find_regex_result consume_regex(const common_regex & regex);
9087

91-
std::optional<consume_regex_result> try_consume_regex(const common_regex & regex);
88+
std::optional<find_regex_result> try_consume_regex(const common_regex & regex);
9289

9390
std::optional<common_json> try_consume_json();
9491
common_json consume_json();

common/chat.cpp

Lines changed: 45 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -570,23 +570,31 @@ static std::string wrap_code_as_arguments(common_chat_msg_parser & builder, cons
570570
static void parse_json_tool_calls(
571571
common_chat_msg_parser & builder,
572572
const std::optional<common_regex> & block_open,
573-
const common_regex & function_regex,
573+
const std::optional<common_regex> & function_regex_start_only,
574+
const std::optional<common_regex> & function_regex,
574575
const common_regex & close_regex,
575576
const std::optional<common_regex> & block_close,
576577
bool allow_raw_python = false,
577578
const std::function<std::string(const common_chat_msg_parser::find_regex_result & fres)> & get_function_name = nullptr) {
578579

579580
auto parse_tool_calls = [&]() {
580581
size_t from = std::string::npos;
582+
auto first = true;
581583
while (true) {
582-
if (auto res = builder.try_find_regex(function_regex, from)) {
584+
auto res = function_regex_start_only && first
585+
? builder.try_consume_regex(*function_regex_start_only)
586+
: function_regex
587+
? builder.try_find_regex(*function_regex, from)
588+
: std::nullopt;
589+
if (res) {
583590
std::string name;
584591
if (get_function_name) {
585592
name = get_function_name(*res);
586593
} else {
587594
GGML_ASSERT(res->groups.size() == 2);
588595
name = builder.str(res->groups[1]);
589596
}
597+
first = false;
590598
if (name.empty()) {
591599
// get_function_name signalled us that we should skip this match and treat it as content.
592600
from = res->groups[0].begin + 1;
@@ -1055,12 +1063,12 @@ static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool w
10551063
static const common_regex function_regex(
10561064
"\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: ");
10571065
static const common_regex close_regex("\\}\\s*");
1058-
static const common_regex builtin_call_regex("<\\|python_tag\\|>");
1059-
1066+
10601067
static const common_regex function_name_regex("\\s*(\\w+)\\s*\\.\\s*call\\(");
10611068
static const common_regex arg_name_regex("\\s*(\\w+)\\s*=\\s*");
1062-
1069+
10631070
if (with_builtin_tools) {
1071+
static const common_regex builtin_call_regex("<\\|python_tag\\|>");
10641072
if (auto res = builder.try_find_regex(builtin_call_regex)) {
10651073
builder.add_content(res->prelude);
10661074

@@ -1094,7 +1102,13 @@ static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool w
10941102
return;
10951103
}
10961104
}
1097-
parse_json_tool_calls(builder, std::nullopt, function_regex, close_regex, std::nullopt);
1105+
parse_json_tool_calls(
1106+
builder,
1107+
/* block_open= */ std::nullopt,
1108+
/* function_regex_start_only= */ function_regex,
1109+
/* function_regex= */ std::nullopt,
1110+
close_regex,
1111+
std::nullopt);
10981112

10991113
}
11001114

@@ -1175,7 +1189,13 @@ static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
11751189
static const common_regex function_regex("(?:<|tool▁call▁begin|>)?function<|tool▁sep|>([^\n]+)\n```json\n");
11761190
static const common_regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>");
11771191

1178-
parse_json_tool_calls(builder, tool_calls_begin, function_regex, close_regex, tool_calls_end);
1192+
parse_json_tool_calls(
1193+
builder,
1194+
/* block_open= */ tool_calls_begin,
1195+
/* function_regex_start_only= */ std::nullopt,
1196+
function_regex,
1197+
close_regex,
1198+
tool_calls_end);
11791199
}
11801200

11811201
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
@@ -1278,17 +1298,21 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
12781298
return data;
12791299
}
12801300
static void common_chat_parse_functionary_v3_2(common_chat_msg_parser & builder) {
1281-
static const common_regex function_regex(R"((>>>)?(\w+\n\{|python\n|all\n))");
1301+
static const common_regex function_regex_start_only(R"((\w+\n\{|python\n|all\n))");
1302+
static const common_regex function_regex(R"(>>>(\w+\n\{|python\n|all\n))");
12821303
static const common_regex close_regex(R"(\s*)");
12831304

1284-
parse_json_tool_calls(builder, std::nullopt, function_regex, close_regex, std::nullopt, /* allow_raw_python= */ true,
1305+
parse_json_tool_calls(
1306+
builder,
1307+
std::nullopt,
1308+
function_regex_start_only,
1309+
function_regex,
1310+
close_regex,
1311+
std::nullopt,
1312+
/* allow_raw_python= */ true,
12851313
/* get_function_name= */ [&](const auto & res) -> std::string {
12861314
auto at_start = res.groups[0].begin == 0;
1287-
if (at_start != res.groups[1].empty()) {
1288-
// Only accept >>> as a match if it's not at the beginning.
1289-
return "";
1290-
}
1291-
auto name = builder.str(res.groups[2]);
1315+
auto name = builder.str(res.groups[1]);
12921316
if (!name.empty() && name.back() == '{') {
12931317
// Unconsume the opening brace '{' to ensure the JSON parsing goes well.
12941318
builder.move_back(1);
@@ -1370,7 +1394,13 @@ static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser
13701394
static const common_regex function_regex(R"(<function=(\w+)>)");
13711395
static const common_regex close_regex(R"(</function>)");
13721396

1373-
parse_json_tool_calls(builder, std::nullopt, function_regex, close_regex, std::nullopt);
1397+
parse_json_tool_calls(
1398+
builder,
1399+
/* block_open= */ std::nullopt,
1400+
/* function_regex_start_only= */ std::nullopt,
1401+
function_regex,
1402+
close_regex,
1403+
std::nullopt);
13741404
}
13751405

13761406
static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) {

tests/test-chat.cpp

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1193,6 +1193,13 @@ static void test_template_output_parsers() {
11931193
inputs_tools_builtin)
11941194
.format);
11951195

1196+
assert_equals(
1197+
message_assist_call,
1198+
common_chat_parse(
1199+
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
1200+
/* is_partial= */ false,
1201+
{COMMON_CHAT_FORMAT_LLAMA_3_X}));
1202+
11961203
// test_templates(tmpls.get(), end_tokens, message_assist, tools, R"(?)", /* expect_grammar_triggered= */ false);
11971204
test_templates(tmpls.get(), end_tokens, message_assist_call_code_interpreter, llama_3_1_tools,
11981205
"<|python_tag|>code_interpreter.call(code=\"print('hey')\")");

0 commit comments

Comments
 (0)