Skip to content

Commit cd63ba4

Browse files
author
ochafik
committed
beef up test-chat-handler w/ delta expectations
1 parent ba10b47 commit cd63ba4

File tree

3 files changed

+209
-284
lines changed

3 files changed

+209
-284
lines changed

common/chat-handler.cpp

Lines changed: 37 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -541,31 +541,35 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common
541541
common_chat_data data;
542542

543543
data.grammar_lazy = params.tool_choice != "required";
544-
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
545-
std::vector<std::string> first_tool_rules;
546-
std::vector<std::string> subsequent_tool_rules;
547-
foreach_function(params.tools, [&](const json & tool) {
548-
const auto & function = tool["function"];
549-
std::string name = function["name"];
550-
auto parameters = function["parameters"];
551-
auto args_rule = builder.add_schema(name + "-args", parameters);
552-
first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
553-
subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule));
554-
data.grammar_triggers.push_back({name, /* .at_start = */ true});
555-
data.grammar_triggers.push_back({">>>" + name, /* .at_start = */ false});
556-
});
557-
auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
558-
if (params.parallel_tool_calls) {
559-
auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space";
560-
builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*");
561-
} else {
562-
builder.add_rule("root", first_rule);
563-
}
544+
if (!params.tools.is_null() && !params.tools.empty()) {
545+
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
546+
std::vector<std::string> first_tool_rules;
547+
std::vector<std::string> subsequent_tool_rules;
548+
foreach_function(params.tools, [&](const json & tool) {
549+
const auto & function = tool["function"];
550+
std::string name = function["name"];
551+
auto parameters = function["parameters"];
552+
auto args_rule = builder.add_schema(name + "-args", parameters);
553+
first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
554+
subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule));
555+
data.grammar_triggers.push_back({name, /* .at_start = */ true});
556+
data.grammar_triggers.push_back({">>>" + name, /* .at_start = */ false});
557+
});
558+
auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
559+
if (params.parallel_tool_calls) {
560+
auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space";
561+
builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*");
562+
} else {
563+
builder.add_rule("root", first_rule);
564+
}
564565

565-
}, grammar_options);
566+
}, grammar_options);
567+
data.format = "functionary v3.2 tool calls";
568+
} else {
569+
data.format = "functionary v3.2 content-only";
570+
}
566571

567572
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
568-
data.format = "functionary v3.2 tool calls";
569573
data.parser = [params](const std::string & input) {
570574
static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
571575
static std::regex close_regex(R"($|(?=>>>))");
@@ -763,21 +767,24 @@ static common_chat_data common_chat_init_without_tools(const common_chat_templat
763767
}
764768

765769
common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params) {
766-
if (params.tools.is_null() || params.tool_choice == "none") {
767-
return common_chat_init_without_tools(tmpl, params);
768-
}
769-
770-
if (!params.grammar.empty()) {
770+
auto has_tools = params.tools.is_null() || params.tool_choice == "none";
771+
if (has_tools && !params.grammar.empty()) {
771772
throw std::runtime_error("Cannot specify grammar with tools");
772773
}
773774

774775
const auto & src = tmpl.source();
775-
if (src.find("<tool_call>") != std::string::npos) {
776-
return common_chat_init_hermes_2_pro_tool_call(tmpl, params);
777-
}
778776
if (src.find(">>>all") != std::string::npos) {
777+
// Functionary prepends "all\n" to plain content outputs, so we use the parser no matter when
779778
return common_chat_init_functionary_v3_2_tool_call(tmpl, params);
780779
}
780+
781+
if (has_tools) {
782+
return common_chat_init_without_tools(tmpl, params);
783+
}
784+
785+
if (src.find("<tool_call>") != std::string::npos) {
786+
return common_chat_init_hermes_2_pro_tool_call(tmpl, params);
787+
}
781788
if (src.find("<|start_header_id|>") != std::string::npos
782789
&& src.find("<function=") != std::string::npos) {
783790
return common_chat_init_functionary_v3_1_llama_3_1_tool_call(tmpl, params);

common/chat-template.hpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,13 @@ class chat_template {
6565
try_raw_render({
6666
{{"role", "user"}, {"content", "Hey"}},
6767
}, {
68-
{{"name", "some_tool"}, {"parameters", {{"type", "string"}}}},
68+
{
69+
{"type", "function"},
70+
{"function", {
71+
{"name", "some_tool"},
72+
{"parameters", {{"type", "string"}}},
73+
}},
74+
},
6975
}, false).find("some_tool") != std::string::npos;
7076

7177
requires_object_arguments_ =

0 commit comments

Comments
 (0)