Skip to content

Commit e3c372c

Browse files
author
ochafik
committed
move/fix detection of functionary v3.1 before llama 3.x, fix & test their non-tool mode
1 parent 543b73e commit e3c372c

File tree

2 files changed

+60
-49
lines changed

2 files changed

+60
-49
lines changed

common/chat.cpp

Lines changed: 55 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1013,15 +1013,17 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te
10131013
builder.add_rule("root", string_join(tool_rules, " | "));
10141014
data.additional_stops.push_back("<|eom_id|>");
10151015
});
1016+
data.format = allow_python_tag_builtin_tools && !builtin_tools.empty()
1017+
? COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS
1018+
: COMMON_CHAT_FORMAT_LLAMA_3_X;
1019+
} else {
1020+
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
10161021
}
10171022
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
10181023
{"date_string", format_time(inputs.now, "%d %b %Y")},
10191024
{"tools_in_user_message", false},
10201025
{"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
10211026
});
1022-
data.format = allow_python_tag_builtin_tools && !builtin_tools.empty()
1023-
? COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS
1024-
: COMMON_CHAT_FORMAT_LLAMA_3_X;
10251027
return data;
10261028
}
10271029
static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bool with_builtin_tools = false) {
@@ -1296,55 +1298,60 @@ static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & in
12961298
static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
12971299
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
12981300
common_chat_params data;
1299-
json tools = inputs.tools.is_null() ? inputs.tools : json::array();
1300-
std::string python_code_argument_name;
1301-
auto has_raw_python = false;
13021301

1303-
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
1304-
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
1305-
std::vector<std::string> tool_rules;
1306-
foreach_function(inputs.tools, [&](const json & tool) {
1307-
const auto & function = tool.at("function");
1308-
const auto & parameters = function.at("parameters");
1309-
std::string name = function.at("name");
1310-
if (name == "python" || name == "ipython") {
1311-
if (!parameters.contains("type")) {
1312-
throw std::runtime_error("Missing type in python tool");
1313-
}
1314-
has_raw_python = true;
1315-
const auto & type = parameters.at("type");
1316-
if (type == "object") {
1317-
auto properties = parameters.at("properties");
1318-
for (auto it = properties.begin(); it != properties.end(); ++it) {
1319-
if (it.value().at("type") == "string") {
1320-
if (!python_code_argument_name.empty()) {
1321-
throw std::runtime_error("Multiple string arguments found in python tool");
1302+
if (!inputs.tools.is_null()) {
1303+
json tools = inputs.tools.is_null() ? inputs.tools : json::array();
1304+
std::string python_code_argument_name;
1305+
auto has_raw_python = false;
1306+
1307+
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
1308+
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
1309+
std::vector<std::string> tool_rules;
1310+
foreach_function(inputs.tools, [&](const json & tool) {
1311+
const auto & function = tool.at("function");
1312+
const auto & parameters = function.at("parameters");
1313+
std::string name = function.at("name");
1314+
if (name == "python" || name == "ipython") {
1315+
if (!parameters.contains("type")) {
1316+
throw std::runtime_error("Missing type in python tool");
1317+
}
1318+
has_raw_python = true;
1319+
const auto & type = parameters.at("type");
1320+
if (type == "object") {
1321+
auto properties = parameters.at("properties");
1322+
for (auto it = properties.begin(); it != properties.end(); ++it) {
1323+
if (it.value().at("type") == "string") {
1324+
if (!python_code_argument_name.empty()) {
1325+
throw std::runtime_error("Multiple string arguments found in python tool");
1326+
}
1327+
python_code_argument_name = it.key();
13221328
}
1323-
python_code_argument_name = it.key();
13241329
}
1330+
if (python_code_argument_name.empty()) {
1331+
throw std::runtime_error("No string argument found in python tool");
1332+
}
1333+
} else if (type != "string") {
1334+
throw std::runtime_error("Invalid type in python tool: " + type.dump());
13251335
}
1326-
if (python_code_argument_name.empty()) {
1327-
throw std::runtime_error("No string argument found in python tool");
1328-
}
1329-
} else if (type != "string") {
1330-
throw std::runtime_error("Invalid type in python tool: " + type.dump());
13311336
}
1337+
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
1338+
});
1339+
if (has_raw_python) {
1340+
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
1341+
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
1342+
data.preserved_tokens.push_back("<|python_tag|>");
13321343
}
1333-
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
1344+
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
1345+
builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
1346+
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function="});
13341347
});
1335-
if (has_raw_python) {
1336-
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
1337-
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
1338-
data.preserved_tokens.push_back("<|python_tag|>");
1339-
}
1340-
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
1341-
builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
1342-
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function="});
1343-
});
1348+
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1;
1349+
} else {
1350+
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
1351+
}
13441352

13451353
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
13461354
// TODO: if (has_raw_python)
1347-
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1;
13481355
return data;
13491356
}
13501357
static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::string & input) {
@@ -1656,6 +1663,12 @@ static common_chat_params common_chat_templates_apply_jinja(
16561663
return common_chat_params_init_firefunction_v2(tmpl, params);
16571664
}
16581665

1666+
// Functionary v3.1 (w/ tools)
1667+
if (src.find("<|start_header_id|>") != std::string::npos
1668+
&& src.find("<function=") != std::string::npos) {
1669+
return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, params);
1670+
}
1671+
16591672
// Llama 3.1, 3.2, 3.3 (also requires date_string so using it even w/o tools)
16601673
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
16611674
auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
@@ -1667,12 +1680,6 @@ static common_chat_params common_chat_templates_apply_jinja(
16671680
return common_chat_params_init_without_tools(tmpl, params);
16681681
}
16691682

1670-
// Functionary v3.1 (w/ tools)
1671-
if (src.find("<|start_header_id|>") != std::string::npos
1672-
&& src.find("<function=") != std::string::npos) {
1673-
return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, params);
1674-
}
1675-
16761683
// Mistral Nemo (w/ tools)
16771684
if (src.find("[TOOL_CALLS]") != std::string::npos) {
16781685
return common_chat_params_init_mistral_nemo(tmpl, params);

tests/test-chat.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -815,6 +815,8 @@ static void test_template_output_parsers() {
815815
std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };
816816

817817
assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
818+
assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY,
819+
common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
818820

819821
test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
820822
test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
@@ -825,7 +827,9 @@ static void test_template_output_parsers() {
825827
std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };
826828

827829
assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
828-
common_chat_templates_apply(tmpls.get(), inputs_tools).format);
830+
common_chat_templates_apply(tmpls.get(), inputs_tools).format);
831+
assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY,
832+
common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
829833

830834
test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
831835
test_templates(tmpls.get(), end_tokens, message_assist_call, tools,

0 commit comments

Comments
 (0)