Skip to content

Commit f7078ca

Browse files
author
ochafik
committed
tool-call: fix functionary v3.1 required test
1 parent 5ec4c5e commit f7078ca

File tree

3 files changed

+52
-32
lines changed

3 files changed

+52
-32
lines changed

common/chat-handler.cpp

Lines changed: 16 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -102,6 +102,11 @@ static common_chat_msg parse_json_tool_calls(const json & tools, const std::stri
102102

103103
json arguments;
104104
if (!parse_json(it, end, arguments)) {
105+
if (name == "python" && std::regex_match("", close_regex)) {
106+
std::string src(it, end);
107+
result.tool_calls.push_back({name, src, /* id= */ ""});
108+
break;
109+
}
105110
throw std::runtime_error("Failed to parse json tool call arguments");
106111
}
107112
if (!std::regex_search(it, end, match, close_regex)) {
@@ -390,11 +395,11 @@ static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_cha
390395
static common_chat_data common_chat_init_llama_3_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params, bool uses_python_tag, bool eagerly_match_any_json) {
391396
auto builtin_tools = json {"wolfram_alpha", "brave_search"};
392397
common_chat_data data;
393-
398+
auto has_python = false;
399+
394400
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
395401
std::vector<std::string> tool_rules;
396402

397-
auto has_python = false;
398403

399404
for (const auto & tool : params.tools) {
400405
if (!tool.contains("type")) {
@@ -433,7 +438,7 @@ static common_chat_data common_chat_init_llama_3_tool_calls(const common_chat_te
433438
}
434439
}
435440

436-
if (has_python) {
441+
if (has_python && uses_python_tag) {
437442
tool_rules.push_back(builder.add_rule("ipython-call", "\"<|python_tag|>\" .*"));
438443
if (params.tool_choice != "required") {
439444
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
@@ -453,8 +458,8 @@ static common_chat_data common_chat_init_llama_3_tool_calls(const common_chat_te
453458
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true, {
454459
{"builtin_tools", builtin_tools},
455460
});
456-
data.parser = std::make_unique<monolithic_chat_parser>([params, uses_python_tag](const std::string & input) -> common_chat_msg {
457-
if (uses_python_tag) {
461+
data.parser = std::make_unique<monolithic_chat_parser>([params, has_python, uses_python_tag](const std::string & input) -> common_chat_msg {
462+
if (has_python && uses_python_tag) {
458463
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
459464
std::smatch match;
460465
if (std::regex_search(input, match, python_tag_regex)) {
@@ -521,10 +526,10 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_tool_call(const
521526
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
522527
common_chat_data data;
523528

529+
auto has_python = false;
524530
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
525531
std::vector<std::string> first_tool_rules;
526532
std::vector<std::string> subsequent_tool_rules;
527-
auto has_python = false;
528533
for (const auto & tool : params.tools) {
529534
if (!tool.contains("type")) {
530535
continue;
@@ -544,7 +549,7 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_tool_call(const
544549
}
545550
}
546551
}
547-
auto first_rule = builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
552+
auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
548553
// Note: if there's a python rule, it needs to come last.
549554
auto python_rule = builder.add_rule("python-call", "\"python\\n\" .*");
550555
if (has_python && params.tool_choice != "required") {
@@ -553,14 +558,14 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_tool_call(const
553558
}
554559
if (params.parallel_tool_calls) {
555560
auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space";
556-
builder.add_rule("root", python_rule + " | " + first_rule + " (" + subsequent_rule + ")*" + (has_python ? " ( \">>>\\n\" " + python_rule + " )?" : ""));
561+
builder.add_rule("root", first_rule.empty() ? python_rule : python_rule + " | " + first_rule + " (" + subsequent_rule + ")*" + (has_python ? " ( \">>>\\n\" " + python_rule + " )?" : ""));
557562
} else {
558-
builder.add_rule("root", first_rule + (has_python ? " | " + python_rule : ""));
563+
builder.add_rule("root", first_rule.empty() ? python_rule : first_rule + (has_python ? " | " + python_rule : ""));
559564
}
560565
}, grammar_options);
561566

562567
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
563-
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
568+
data.parser = std::make_unique<monolithic_chat_parser>([params, has_python](const std::string & input) -> common_chat_msg {
564569
static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
565570
static std::regex close_regex(R"($|(?=>>>))");
566571
return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true);
@@ -723,7 +728,7 @@ static common_chat_data common_chat_init_without_tools(const common_chat_templat
723728
}
724729

725730
common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params) {
726-
if (params.tools.is_null()) {
731+
if (params.tools.is_null() || params.tool_choice == "none") {
727732
return common_chat_init_without_tools(tmpl, params);
728733
}
729734

examples/server/server.cpp

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3788,11 +3788,14 @@ int main(int argc, char ** argv) {
37883788
/* .tools = */ json_value(data, "tools", json()),
37893789
/* .tool_choice = */ json_value(data, "tool_choice", std::string("auto")),
37903790
/* .json_schema = */ json_value(data, "json_schema", json()),
3791-
/* .parallel_tool_calls = */ json_value(data, "json_schema", true),
3792-
/* .stream = */ json_value(data, "json_schema", false),
3791+
/* .parallel_tool_calls = */ json_value(data, "parallel_tool_calls", false),
3792+
/* .stream = */ json_value(data, "stream", false),
37933793
/* .grammar = */ json_value(data, "grammar", std::string("")),
37943794
});
37953795
if (data.contains("grammar")) {
3796+
if (!chat_data.grammar.empty()) {
3797+
throw std::runtime_error("Cannot provide grammar and tools");
3798+
}
37963799
chat_data.grammar = data.at("grammar");
37973800
}
37983801
} else {

examples/server/tests/unit/test_chat_completion.py

Lines changed: 31 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -226,23 +226,31 @@ def test_chat_completion_with_timings_per_token():
226226
}
227227

228228

229-
@pytest.mark.parametrize("template_name,n_predict,tool,argument_key", [
230-
("meetkai-functionary-medium-v3.1", 128, TEST_TOOL, "success"),
231-
("meetkai-functionary-medium-v3.1", 128, PYTHON_TOOL, "code"),
232-
("meetkai-functionary-medium-v3.2", 128, TEST_TOOL, "success"),
233-
("meetkai-functionary-medium-v3.2", 128, PYTHON_TOOL, "code"),
234-
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, TEST_TOOL, "success"),
235-
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, PYTHON_TOOL, "code"),
236-
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, TEST_TOOL, "success"),
237-
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, PYTHON_TOOL, "code"),
238-
("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, TEST_TOOL, "success"),
239-
("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, PYTHON_TOOL, "code"),
240-
("meta-llama-Llama-3.2-3B-Instruct", 128, TEST_TOOL, "success"),
241-
("meta-llama-Llama-3.2-3B-Instruct", 128, PYTHON_TOOL, "code"),
242-
("mistralai-Mistral-Nemo-Instruct-2407", 128, TEST_TOOL, "success"),
243-
("mistralai-Mistral-Nemo-Instruct-2407", 128, PYTHON_TOOL, "code"),
229+
@pytest.mark.parametrize("template_name,tool,argument_key", [
230+
("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"),
231+
("meetkai-functionary-medium-v3.1", PYTHON_TOOL, None),
232+
("meetkai-functionary-medium-v3.1", CODE_INTEPRETER_TOOL, None),
233+
("meetkai-functionary-medium-v3.2", TEST_TOOL, "success"),
234+
("meetkai-functionary-medium-v3.2", PYTHON_TOOL, None),
235+
("meetkai-functionary-medium-v3.2", CODE_INTEPRETER_TOOL, None),
236+
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", TEST_TOOL, "success"),
237+
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", PYTHON_TOOL, None),
238+
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", CODE_INTEPRETER_TOOL, None),
239+
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", TEST_TOOL, "success"),
240+
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", PYTHON_TOOL, None),
241+
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", CODE_INTEPRETER_TOOL, None),
242+
("meta-llama-Meta-Llama-3.1-8B-Instruct", TEST_TOOL, "success"),
243+
("meta-llama-Meta-Llama-3.1-8B-Instruct", PYTHON_TOOL, None),
244+
("meta-llama-Meta-Llama-3.1-8B-Instruct", CODE_INTEPRETER_TOOL, None),
245+
("meta-llama-Llama-3.2-3B-Instruct", TEST_TOOL, "success"),
246+
("meta-llama-Llama-3.2-3B-Instruct", PYTHON_TOOL, None),
247+
# # ("meta-llama-Llama-3.2-3B-Instruct", CODE_INTEPRETER_TOOL, None),
248+
("mistralai-Mistral-Nemo-Instruct-2407", TEST_TOOL, "success"),
249+
("mistralai-Mistral-Nemo-Instruct-2407", PYTHON_TOOL, None),
250+
("mistralai-Mistral-Nemo-Instruct-2407", CODE_INTEPRETER_TOOL, None),
244251
])
245-
def test_completion_with_required_tool(template_name: str, n_predict: int, tool: dict, argument_key: str):
252+
def test_completion_with_required_tool(template_name: str, tool: dict, argument_key: str | None):
253+
n_predict = 512
246254
global server
247255
# server = ServerPreset.stories15m_moe()
248256
server.jinja = True
@@ -267,9 +275,13 @@ def test_completion_with_required_tool(template_name: str, n_predict: int, tool:
267275
tool_calls = choice["message"].get("tool_calls")
268276
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
269277
tool_call = tool_calls[0]
270-
assert tool["function"]["name"] == tool_call["function"]["name"]
271-
actual_arguments = json.loads(tool_call["function"]["arguments"])
272-
assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}"
278+
expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"]
279+
assert expected_function_name == tool_call["function"]["name"]
280+
actual_arguments = tool_call["function"]["arguments"]
281+
assert isinstance(actual_arguments, str)
282+
if argument_key is not None:
283+
actual_arguments = json.loads(actual_arguments)
284+
assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}"
273285

274286

275287
@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [

0 commit comments

Comments
 (0)