Skip to content

Commit 64ea080

Browse files
author
ochafik
committed
fix functionary v3.2 raw python!
1 parent d6e680a commit 64ea080

File tree

4 files changed

+100
-49
lines changed

4 files changed

+100
-49
lines changed

common/chat-parser.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,8 @@ std::string common_chat_msg_parser::consume_rest() {
153153
}
154154

155155
// Tries to find the regex, consumes it (pos right after it) and gives the prelude (right before it) and the groups to the callback.
156-
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_regex(const common_regex & regex) {
157-
auto m = regex.search(input_, pos_);
156+
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_regex(const common_regex & regex, size_t from) {
157+
auto m = regex.search(input_, from == std::string::npos ? pos_ : from);
158158
if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) {
159159
return std::nullopt;
160160
}

common/chat-parser.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class common_chat_msg_parser {
2525
public:
2626
common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_reasoning_syntax & reasoning_syntax);
2727
const std::string & input() const { return input_; }
28+
size_t pos() const { return pos_; }
2829
const std::string & healing_marker() const { return healing_marker_; }
2930
const bool & is_partial() const { return is_partial_; }
3031
const common_chat_msg & result() const { return result_; }
@@ -79,7 +80,7 @@ class common_chat_msg_parser {
7980
std::vector<common_string_range> groups;
8081
};
8182

82-
std::optional<find_regex_result> try_find_regex(const common_regex & regex);
83+
std::optional<find_regex_result> try_find_regex(const common_regex & regex, size_t from = std::string::npos);
8384

8485
struct consume_regex_result {
8586
std::vector<common_string_range> groups;

common/chat.cpp

Lines changed: 64 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -561,23 +561,51 @@ static void parse_json_tool_calls(
561561
const common_regex & close_regex,
562562
const std::optional<common_regex> & block_close,
563563
bool allow_raw_python = false,
564-
const std::function<bool(const std::string & name)> & is_function = nullptr) {
564+
const std::function<std::string(const common_chat_msg_parser::find_regex_result & fres)> & get_function_name = nullptr) {
565565

566566
auto parse_tool_calls = [&]() {
567+
size_t from = std::string::npos;
567568
while (true) {
568-
if (auto res = builder.try_find_regex(function_regex)) {
569-
GGML_ASSERT(res->groups.size() == 2);
570-
auto name = builder.str(res->groups[1]);
569+
if (auto res = builder.try_find_regex(function_regex, from)) {
570+
std::string name;
571+
if (get_function_name) {
572+
name = get_function_name(*res);
573+
} else {
574+
GGML_ASSERT(res->groups.size() == 2);
575+
name = builder.str(res->groups[1]);
576+
}
577+
if (name.empty()) {
578+
// get_function_name signalled us that we should skip this match and treat it as content.
579+
from = res->groups[0].begin + 1;
580+
continue;
581+
} else {
582+
from = std::string::npos;
583+
}
571584
builder.add_content(res->prelude);
572-
if (is_function && !is_function(name)) {
585+
if (auto partial = builder.try_consume_json({{}})) {
586+
std::string arguments = partial->json.dump();
587+
if (!builder.add_tool_call(name, "", arguments, partial->healing_marker)) {
588+
builder.incomplete("incomplete tool call");
589+
}
590+
builder.consume_regex(close_regex);
591+
} else if (name == "python" && allow_raw_python) {
592+
auto code = builder.consume_rest();
593+
std::string arguments;
594+
common_healing_marker healing_marker;
595+
if (builder.is_partial()) {
596+
healing_marker.json_dump_marker = healing_marker.marker = builder.healing_marker();
597+
arguments = (json {{"code", code + healing_marker.marker}}).dump();
598+
} else {
599+
arguments = (json {{"code", code}}).dump();
600+
}
601+
if (!builder.add_tool_call(name, "", arguments, healing_marker)) {
602+
builder.incomplete("incomplete tool call");
603+
}
573604
return;
574-
}
575-
auto partial = builder.consume_json({{}});
576-
std::string arguments = partial.json.dump();
577-
if (!builder.add_tool_call(name, "", arguments, partial.healing_marker)) {
605+
} else {
578606
builder.incomplete("incomplete tool call");
607+
return;
579608
}
580-
builder.consume_regex(close_regex);
581609
} else {
582610
break;
583611
}
@@ -863,7 +891,7 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
863891
if (!inputs.parallel_tool_calls) {
864892
schema["maxItems"] = 1;
865893
}
866-
builder.add_rule("root",
894+
builder.add_rule("root",
867895
std::string(data.thinking_forced_open ? "\"<|END_THINKING|>\" space " : "") +
868896
"\"<|START_ACTION|>\" " + builder.add_schema("tool_calls", schema) + " \"<|END_ACTION|>\"");
869897
});
@@ -1193,6 +1221,7 @@ static void common_chat_parse_firefunction_v2(common_chat_msg_parser & builder)
11931221
static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct templates_params & inputs) {
11941222
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
11951223
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
1224+
// If the function is python, we also allow raw python code (if the line after `python\n` doesn't start w/ opening `{`), which the model seems to prefer for multiline code.
11961225
common_chat_params data;
11971226
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
11981227
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2;
@@ -1206,24 +1235,17 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
12061235
std::string name = function.at("name");
12071236
auto parameters = function.at("parameters");
12081237
builder.resolve_refs(parameters);
1238+
std::string args_pattern = "[\\s\\S]*";
12091239
auto args_rule = builder.add_schema(name + "-args", parameters);
1240+
if (name == "python") {
1241+
args_pattern = "\\{" + args_pattern;
1242+
args_rule = builder.add_rule(name + "-maybe-raw-args", args_rule + " | [^{] .*");
1243+
}
12101244
first_tool_rules.push_back(builder.add_rule(name + "-call", "( \"assistant<|end_header_id|>\\n\" )? \"" + name + "\\n\" " + args_rule));
12111245
subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule));
12121246
data.grammar_triggers.push_back({
12131247
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
1214-
"(" + regex_escape(name + "\n") + ")[\\s\\S]*",
1215-
});
1216-
data.grammar_triggers.push_back({
1217-
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
1218-
"(" + regex_escape("assistant<|end_header_id|>\n" + name + "\n") + ")[\\s\\S]*",
1219-
});
1220-
data.grammar_triggers.push_back({
1221-
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
1222-
regex_escape(">>>" + name + "\n"),
1223-
});
1224-
data.grammar_triggers.push_back({
1225-
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
1226-
">>>assistant<|end_header_id|>\n" + name,
1248+
"((?:[\\s\\S]*?>>>)?" + regex_escape(name) + "\n)" + args_pattern,
12271249
});
12281250
});
12291251
data.preserved_tokens = {
@@ -1242,30 +1264,27 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
12421264
return data;
12431265
}
12441266
static void common_chat_parse_functionary_v3_2(common_chat_msg_parser & builder) {
1245-
static const common_regex function_regex(R"(>>>(\w+)\n)");
1267+
static const common_regex function_regex(R"((>>>)?(\w+\n\{|python\n|all\n))");
12461268
static const common_regex close_regex(R"(\s*)", /* at_start= */ true);
12471269

1248-
static const common_regex initial_function_regex(R"((?:assistant<\|end_header_id\|>\n)?(\w+)\n\{\s*")", /* at_start= */ true);
1249-
1250-
if (auto res = builder.try_consume_regex(initial_function_regex)) {
1251-
auto name = builder.str(res->groups[1]);
1252-
if (name == "all") {
1253-
builder.move_to(res->groups[1].end + 1);
1254-
builder.add_content(builder.consume_rest());
1255-
return;
1256-
}
1257-
// Move to just after the function name + newline
1258-
builder.move_to(res->groups[1].end + 1);
1259-
auto args = builder.consume_json({{}});
1260-
if (!builder.add_tool_call(name, "", args.json.dump(), args.healing_marker)) {
1261-
builder.incomplete("Incomplete tool call");
1262-
}
1263-
builder.consume_spaces();
1264-
}
1265-
12661270
parse_json_tool_calls(builder, std::nullopt, function_regex, close_regex, std::nullopt, /* allow_raw_python= */ true,
1267-
/* is_function= */ [&](const auto & name) {
1268-
return name != "all";
1271+
/* get_function_name= */ [&](const auto & res) -> std::string {
1272+
auto at_start = res.groups[0].begin == 0;
1273+
if (at_start != res.groups[1].empty()) {
1274+
// Only accept >>> as a match if it's not at the beginning.
1275+
return "";
1276+
}
1277+
auto name = builder.str(res.groups[2]);
1278+
if (!name.empty() && name.back() == '{') {
1279+
// Unconsume the opening brace '{' to ensure the JSON parsing goes well.
1280+
builder.move_back(1);
1281+
}
1282+
auto idx = name.find_last_not_of("\n{");
1283+
name = name.substr(0, idx + 1);
1284+
if (at_start && name == "all") {
1285+
return "";
1286+
}
1287+
return name;
12691288
});
12701289
}
12711290

tests/test-chat.cpp

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,9 @@ const std::vector<common_chat_tool_call> tool_calls_idx {
402402
const std::vector<common_chat_tool_call> tool_calls_id {
403403
{ "special_function", "{\"arg1\": 1}", /* .id = */ "123456789" },
404404
};
405+
const std::vector<common_chat_tool_call> tool_calls_python {
406+
{ "python", "{\"code\": \"print('hey')\"}", /* .id = */ "" },
407+
};
405408

406409
const common_chat_msg message_assist_empty {
407410
"assistant",
@@ -488,7 +491,7 @@ const common_chat_msg message_assist_call_python {
488491
"assistant",
489492
"",
490493
/* .content_parts = */ {},
491-
{ { "python", "{\"code\": \"print('hey')\"}", /* .id = */ "" } },
494+
tool_calls_python,
492495
/* .reasoning_content = */ "",
493496
/* .tool_name = */ "",
494497
/* .tool_call_id = */ "",
@@ -977,11 +980,39 @@ static void test_template_output_parsers() {
977980
assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
978981
assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
979982

983+
assert_msg_equals(
984+
common_chat_msg {
985+
"assistant",
986+
"Hello, world!\nnono\nWhat's up?",
987+
/* .content_parts = */ {},
988+
/* .tool_calls = */ tool_calls,
989+
/* .reasoning_content = */ "",
990+
/* .tool_name = */ "",
991+
/* .tool_call_id = */ ""
992+
},
993+
common_chat_parse(
994+
"all\n"
995+
"Hello, world!\n"
996+
"nono\n"
997+
"What's up?\n"
998+
">>>special_function\n"
999+
"{\"arg1\": 1}\n",
1000+
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2));
1001+
assert_msg_equals(message_assist_call_python,
1002+
common_chat_parse(
1003+
"python\n"
1004+
"print('hey')",
1005+
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2));
9801006
assert_msg_equals(message_assist_call,
9811007
common_chat_parse(
9821008
"special_function\n"
9831009
"{\"arg1\": 1} \n ",
9841010
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2));
1011+
assert_msg_equals(message_assist,
1012+
common_chat_parse(
1013+
"all\n"
1014+
"Hello, world!\nWhat's up?",
1015+
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2));
9851016

9861017
test_templates(tmpls.get(), end_tokens, message_assist, {},
9871018
"all\n"

0 commit comments

Comments
 (0)