Skip to content

Commit 387611a

Browse files
author
ochafik
committed
Merge branch 'date' into tool-diffs
2 parents 90789cd + e3c372c commit 387611a

File tree

3 files changed

+62
-50
lines changed

3 files changed

+62
-50
lines changed

common/chat.cpp

Lines changed: 56 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1092,17 +1092,19 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te
10921092
}
10931093
// Allow a few empty lines on top of the usual constrained json schema space rule.
10941094
builder.add_rule("root", string_join(tool_rules, " | "));
1095+
data.additional_stops.push_back("<|eom_id|>");
10951096
});
1096-
data.additional_stops.push_back("<|eom_id|>");
1097+
data.format = allow_python_tag_builtin_tools && !builtin_tools.empty()
1098+
? COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS
1099+
: COMMON_CHAT_FORMAT_LLAMA_3_X;
1100+
} else {
1101+
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
10971102
}
10981103
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
10991104
{"date_string", format_time(inputs.now, "%d %b %Y")},
11001105
{"tools_in_user_message", false},
11011106
{"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
11021107
});
1103-
data.format = allow_python_tag_builtin_tools && !builtin_tools.empty()
1104-
? COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS
1105-
: COMMON_CHAT_FORMAT_LLAMA_3_X;
11061108
return data;
11071109
}
11081110
static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) {
@@ -1375,55 +1377,60 @@ static void common_chat_parse_functionary_v3_2(common_chat_msg_parser & builder)
13751377
static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
13761378
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
13771379
common_chat_params data;
1378-
json tools = inputs.tools.is_null() ? inputs.tools : json::array();
1379-
std::string python_code_argument_name;
1380-
auto has_raw_python = false;
13811380

1382-
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
1383-
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
1384-
std::vector<std::string> tool_rules;
1385-
foreach_function(inputs.tools, [&](const json & tool) {
1386-
const auto & function = tool.at("function");
1387-
const auto & parameters = function.at("parameters");
1388-
std::string name = function.at("name");
1389-
if (name == "python" || name == "ipython") {
1390-
if (!parameters.contains("type")) {
1391-
throw std::runtime_error("Missing type in python tool");
1392-
}
1393-
has_raw_python = true;
1394-
const auto & type = parameters.at("type");
1395-
if (type == "object") {
1396-
auto properties = parameters.at("properties");
1397-
for (auto it = properties.begin(); it != properties.end(); ++it) {
1398-
if (it.value().at("type") == "string") {
1399-
if (!python_code_argument_name.empty()) {
1400-
throw std::runtime_error("Multiple string arguments found in python tool");
1381+
if (!inputs.tools.is_null()) {
1382+
json tools = inputs.tools.is_null() ? inputs.tools : json::array();
1383+
std::string python_code_argument_name;
1384+
auto has_raw_python = false;
1385+
1386+
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
1387+
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
1388+
std::vector<std::string> tool_rules;
1389+
foreach_function(inputs.tools, [&](const json & tool) {
1390+
const auto & function = tool.at("function");
1391+
const auto & parameters = function.at("parameters");
1392+
std::string name = function.at("name");
1393+
if (name == "python" || name == "ipython") {
1394+
if (!parameters.contains("type")) {
1395+
throw std::runtime_error("Missing type in python tool");
1396+
}
1397+
has_raw_python = true;
1398+
const auto & type = parameters.at("type");
1399+
if (type == "object") {
1400+
auto properties = parameters.at("properties");
1401+
for (auto it = properties.begin(); it != properties.end(); ++it) {
1402+
if (it.value().at("type") == "string") {
1403+
if (!python_code_argument_name.empty()) {
1404+
throw std::runtime_error("Multiple string arguments found in python tool");
1405+
}
1406+
python_code_argument_name = it.key();
14011407
}
1402-
python_code_argument_name = it.key();
14031408
}
1409+
if (python_code_argument_name.empty()) {
1410+
throw std::runtime_error("No string argument found in python tool");
1411+
}
1412+
} else if (type != "string") {
1413+
throw std::runtime_error("Invalid type in python tool: " + type.dump());
14041414
}
1405-
if (python_code_argument_name.empty()) {
1406-
throw std::runtime_error("No string argument found in python tool");
1407-
}
1408-
} else if (type != "string") {
1409-
throw std::runtime_error("Invalid type in python tool: " + type.dump());
14101415
}
1416+
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
1417+
});
1418+
if (has_raw_python) {
1419+
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
1420+
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
1421+
data.preserved_tokens.push_back("<|python_tag|>");
14111422
}
1412-
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
1423+
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
1424+
builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
1425+
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function="});
14131426
});
1414-
if (has_raw_python) {
1415-
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
1416-
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
1417-
data.preserved_tokens.push_back("<|python_tag|>");
1418-
}
1419-
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
1420-
builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
1421-
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function="});
1422-
});
1427+
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1;
1428+
} else {
1429+
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
1430+
}
14231431

14241432
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
14251433
// TODO: if (has_raw_python)
1426-
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1;
14271434
return data;
14281435
}
14291436
static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser & builder) {
@@ -1714,6 +1721,12 @@ static common_chat_params common_chat_templates_apply_jinja(
17141721
return common_chat_params_init_firefunction_v2(tmpl, params);
17151722
}
17161723

1724+
// Functionary v3.1 (w/ tools)
1725+
if (src.find("<|start_header_id|>") != std::string::npos
1726+
&& src.find("<function=") != std::string::npos) {
1727+
return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, params);
1728+
}
1729+
17171730
// Llama 3.1, 3.2, 3.3 (also requires date_string so using it even w/o tools)
17181731
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
17191732
auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
@@ -1725,12 +1738,6 @@ static common_chat_params common_chat_templates_apply_jinja(
17251738
return common_chat_params_init_without_tools(tmpl, params);
17261739
}
17271740

1728-
// Functionary v3.1 (w/ tools)
1729-
if (src.find("<|start_header_id|>") != std::string::npos
1730-
&& src.find("<function=") != std::string::npos) {
1731-
return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, params);
1732-
}
1733-
17341741
// Mistral Nemo (w/ tools)
17351742
if (src.find("[TOOL_CALLS]") != std::string::npos) {
17361743
return common_chat_params_init_mistral_nemo(tmpl, params);

common/chat.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
#include "common.h"
66
#include <functional>
7+
#include <chrono>
78
#include <string>
89
#include <vector>
910

tests/test-chat.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1051,6 +1051,8 @@ static void test_template_output_parsers() {
10511051
std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };
10521052

10531053
assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
1054+
assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY,
1055+
common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
10541056

10551057
test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
10561058
test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
@@ -1061,7 +1063,9 @@ static void test_template_output_parsers() {
10611063
std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };
10621064

10631065
assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
1064-
common_chat_templates_apply(tmpls.get(), inputs_tools).format);
1066+
common_chat_templates_apply(tmpls.get(), inputs_tools).format);
1067+
assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY,
1068+
common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
10651069

10661070
for (auto is_partial : { false, true }) {
10671071
assert_equals(

0 commit comments

Comments
 (0)