Commit ef9efc9

Author: ochafik

Fix Llama 3.1 (incl. constrained builtin tools, e.g. <|python_tag|>foo.call(arg=value))

1 parent 2d607f1 · commit ef9efc9

File tree

3 files changed: +115 −31 lines

common/chat-handler.cpp
examples/server/tests/unit/test_tool_call.py
tests/test-chat-handler.cpp

common/chat-handler.cpp

Lines changed: 73 additions & 22 deletions
```diff
@@ -207,7 +207,6 @@ static void foreach_function(const json & tools, const std::function<void(const
 }
 
 static common_chat_data common_chat_init_generic_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
-    fprintf(stderr, "[%s]\n", __func__);
     common_chat_data data;
 
     auto tool_call_schemas = json::array();
@@ -318,7 +317,6 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
 }
 
 static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
-    fprintf(stderr, "[%s]\n", __func__);
     common_chat_data data;
     data.grammar_lazy = params.tool_choice != "required";
     data.grammar = build_grammar([&](const common_grammar_builder & builder) {
@@ -358,25 +356,71 @@ static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_cha
     data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
     data.format = "mistral nemo tool calls";
     data.parser = std::make_unique<monolithic_chat_parser>([](const std::string & input) {
-            return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
-        });
+        return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
+    });
     return data;
 }
 
+static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector<std::string> & expected_properties) {
+    if (!parameters.is_object() || !parameters.contains("type") || parameters["type"] != "object" || !parameters.contains("properties") || !parameters.contains("required")) {
+        throw std::runtime_error("Parameters of tool " + name + " must be an object w/ required properties");
+    }
+    const auto & parameters_properties = parameters.at("properties");
+    const auto & parameters_required = parameters.at("required");
+    for (const auto & prop : expected_properties) {
+        if (!parameters_properties.contains(prop)) {
+            throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop);
+        }
+        if (std::find(parameters_required.begin(), parameters_required.end(), json(prop)) == parameters_required.end()) {
+            throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop);
+        }
+    }
+    if (parameters_properties.size() != expected_properties.size()) {
+        throw std::runtime_error("Parameters of tool " + name + " must only have these properties:" + string_join(expected_properties, ", "));
+    }
+}
+
 static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params) {
-    fprintf(stderr, "[%s]\n", __func__);
-    // TODO: get from request body.
-    auto builtin_tools = json {"wolfram_alpha", "brave_search"};
+    auto builtin_tools = json::array();
     common_chat_data data;
-
     data.grammar_lazy = params.tool_choice != "required";
     data.grammar = build_grammar([&](const common_grammar_builder & builder) {
         std::vector<std::string> tool_rules;
 
+        auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
+            if (name == "wolfram_alpha") { // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
+                expect_tool_parameters(name, parameters, {"query"});
+            } else if (name == "web_search" || name == "brave_search") { // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
+                expect_tool_parameters(name, parameters, {"query"});
+            } else if (name == "python" || name == "code_interpreter") { // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
+                expect_tool_parameters(name, parameters, {"code"});
+            } else {
+                return false;
+            }
+
+            std::vector<std::string> kvs;
+            for (const auto & [key, value] : parameters.at("properties").items()) {
+                kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value));
+            }
+
+            tool_rules.push_back(
+                builder.add_rule(
+                    name + "-call",
+                    "\"<|python_tag|>" + name + ".call(\" " + string_join(kvs, " \", \" ") + " \")\""));
+            builtin_tools.push_back(name);
+
+            return true;
+        };
+
         foreach_function(params.tools, [&](const json & tool) {
             const auto & function = tool["function"];
             std::string name = function["name"];
             auto parameters = function["parameters"];
+
+            // https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime
+            if (handle_builtin_tool(name, parameters)) {
+                return;
+            }
             builder.resolve_refs(parameters);
             tool_rules.push_back(
                 builder.add_rule(
@@ -388,30 +432,42 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
                     " \"}\""));
             data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true});
         });
-        tool_rules.push_back(builder.add_rule("builtin-tool-call", "\"<|python_tag|>\" .*"));
-        data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
+        if (!builtin_tools.empty()) {
+            data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
+        }
         builder.add_rule("root", string_join(tool_rules, " | "));
     }, grammar_options);
     data.additional_stops.push_back("<|eom_id|>");
     data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt, {
-        {"builtin_tools", builtin_tools},
+        {"tools_in_user_message", false},
+        {"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
     });
     data.format = "llama 3.1 tool calls";
     data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
         static std::regex function_regex("\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": ");
         static std::regex close_regex("\\}");
-        static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\\((.*)\\)");
+        static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\\.call\\((.*)\\)");
 
         std::smatch match;
        if (std::regex_match(input, match, builtin_call_regex)) {
-            auto arguments = json::parse("[" + match[2].str() + "]");
+            auto name = match[1].str();
+            auto raw_args = match[2].str();
+
+            // TODO: if/when builtin tools start accepting more than 1 argument, use parse_json for real parsing.
+            auto it_eq = raw_args.find('=');
+            auto arg_name = raw_args.substr(0, it_eq);
+            auto arg_value_str = raw_args.substr(it_eq + 1);
+            auto arg_value = json::parse(arg_value_str);
+
             return {
                 /* .role = */ "assistant",
                 /* .content = */ match.prefix().str(),
                 /* .tool_calls = */ {
                     {
                         /* .name = */ match[1],
-                        /* .arguments = */ arguments.dump(),
+                        /* .arguments = */ (json {
+                            {arg_name, arg_value},
+                        }).dump(),
                         /* .id = */ "",
                     },
                 },
@@ -423,7 +479,6 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
 }
 
 static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params) {
-    fprintf(stderr, "[%s]\n", __func__);
     common_chat_data data;
 
     data.grammar_lazy = params.tool_choice != "required";
@@ -462,7 +517,6 @@ static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_
 }
 
 static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
-    fprintf(stderr, "[%s]\n", __func__);
     common_chat_data data;
     data.grammar_lazy = params.tool_choice != "required";
     data.grammar = build_grammar([&](const common_grammar_builder & builder) {
@@ -490,7 +544,6 @@ static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat
 }
 
 static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
-    fprintf(stderr, "[%s]\n", __func__);
     common_chat_data data;
     data.grammar_lazy = params.tool_choice != "required";
     data.grammar = build_grammar([&](const common_grammar_builder & builder) {
@@ -529,7 +582,6 @@ static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_
 }
 
 static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
-    fprintf(stderr, "[%s]\n", __func__);
     // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
     // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
     common_chat_data data;
@@ -574,7 +626,6 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common
 }
 
 static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
-    fprintf(stderr, "[%s]\n", __func__);
     // ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja
     // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
     common_chat_data data;
@@ -651,7 +702,6 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(cons
 }
 
 static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
-    fprintf(stderr, "[%s]\n", __func__);
     common_chat_data data;
     // (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
     data.grammar_lazy = params.tool_choice != "required";
@@ -705,9 +755,11 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
                 if (!parse_json(it, end, call)) {
                     throw std::runtime_error("Failed to parse json tool call");
                 }
+                const auto & arguments = call["arguments"];
                 result.tool_calls.push_back({
                     call["name"],
-                    call["arguments"].dump(),
+                    arguments.dump(),
+                    // arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
                     /* id= */ "",
                 });
                 rit = {it, end, middle_pattern};
@@ -734,7 +786,6 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
 }
 
 static common_chat_data common_chat_init_without_tools(const common_chat_template & tmpl, const struct common_chat_params & params) {
-    fprintf(stderr, "[%s]\n", __func__);
     common_chat_data data;
     data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
     data.format = "content-only";
```
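The parser change above is the heart of the fix: builtin calls are now matched as `name.call(args)` rather than bare `name(args)`, and the single `key=value` argument is split by hand instead of being wrapped in a JSON array. Below is a minimal standalone sketch of that parsing step; the driver `main`, the sample input, and the hand-rolled split are illustrative only, and the real code additionally parses the value with nlohmann::json.

```cpp
// Sketch of the builtin-call parsing path from common/chat-handler.cpp.
#include <iostream>
#include <regex>
#include <string>

int main() {
    // Same pattern as the updated builtin_call_regex above.
    static const std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\\.call\\((.*)\\)");

    // Hypothetical constrained output for the brave_search builtin tool.
    std::string input = "<|python_tag|>brave_search.call(query=\"llama grammars\")";

    std::smatch match;
    if (std::regex_match(input, match, builtin_call_regex)) {
        std::string name     = match[1].str();  // "brave_search"
        std::string raw_args = match[2].str();  // "query=\"llama grammars\""

        // Builtin tools currently take exactly one argument, so splitting on the
        // first '=' is enough; the value is a JSON literal (here, a string).
        auto eq = raw_args.find('=');
        std::string arg_name  = raw_args.substr(0, eq);
        std::string arg_value = raw_args.substr(eq + 1);

        std::cout << name << " -> {\"" << arg_name << "\": " << arg_value << "}\n";
        // Prints: brave_search -> {"query": "llama grammars"}
    }
}
```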

examples/server/tests/unit/test_tool_call.py

Lines changed: 6 additions & 6 deletions
```diff
@@ -63,6 +63,8 @@ def create_server():
 
 
 @pytest.mark.parametrize("template_name,tool,argument_key", [
+    ("meta-llama-Meta-Llama-3.1-8B-Instruct", TEST_TOOL, "success"),
+    ("meta-llama-Meta-Llama-3.1-8B-Instruct", PYTHON_TOOL, "code"),
     ("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"),
     ("meetkai-functionary-medium-v3.1", PYTHON_TOOL, "code"),
     ("meetkai-functionary-medium-v3.2", TEST_TOOL, "success"),
@@ -78,8 +80,6 @@ def create_server():
     ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", TEST_TOOL, "success"),
     ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", PYTHON_TOOL, "code"),
     # TODO: fix these
-    # ("meta-llama-Meta-Llama-3.1-8B-Instruct", TEST_TOOL, "success"),
-    # ("meta-llama-Meta-Llama-3.1-8B-Instruct", PYTHON_TOOL, "code"),
 ])
 def test_completion_with_required_tool_tiny(template_name: str, tool: dict, argument_key: str | None):
     n_predict = 512
@@ -118,6 +118,8 @@ def test_completion_with_required_tool_tiny(template_name: str, tool: dict, argu
 
 @pytest.mark.slow
 @pytest.mark.parametrize("tool,argument_key,hf_repo,hf_file,template_override", [
+    (TEST_TOOL, "success", "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
+    (PYTHON_TOOL, "code", "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
     (TEST_TOOL, "success", "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
     (PYTHON_TOOL, "code", "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
     (TEST_TOOL, "success", "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
@@ -139,8 +141,6 @@ def test_completion_with_required_tool_tiny(template_name: str, tool: dict, argu
     # TODO: fix these
     # (TEST_TOOL, "success", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF", "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf", None),
     # (PYTHON_TOOL, "code", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF", "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf", None),
-    # (TEST_TOOL, "success", "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
-    # (PYTHON_TOOL, "code", "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
 ])
 def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
     n_predict = 512
@@ -218,6 +218,7 @@ def test_completion_without_tool_call(template_name: str, n_predict: int, tools:
 
 @pytest.mark.slow
 @pytest.mark.parametrize("hf_repo,hf_file,template_override", [
+    ("lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
     ("bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
     ("bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
     ("bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
@@ -229,7 +230,6 @@ def test_completion_without_tool_call(template_name: str, n_predict: int, tools:
     ("bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
     # TODO: fix these
     # ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF", "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf", None),
-    # ("lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
 ])
 def test_weather_tool_call(hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
     global server
@@ -267,6 +267,7 @@ def test_weather_tool_call(hf_repo: str, hf_file: str, template_override: Tuple[
 
 @pytest.mark.slow
 @pytest.mark.parametrize("expected_arguments,hf_repo,hf_file,template_override", [
+    ('{"code":"print("}', "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
     (None, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
     (None, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
     ('{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
@@ -277,7 +278,6 @@ def test_weather_tool_call(hf_repo: str, hf_file: str, template_override: Tuple[
     (None, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
     (None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q4_K_M.gguf", None),
     # TODO: fix these
-    # ('{"code":"print("}', "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
     # (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF", "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf", None),
 ])
 def test_hello_world_tool_call(expected_arguments: str | None, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
```

tests/test-chat-handler.cpp

Lines changed: 36 additions & 3 deletions
```diff
@@ -119,7 +119,25 @@ const auto python_tool = json::parse(R"({
      }
    }
  }
 })");
+const auto code_interpreter_tool = json::parse(R"({
+  "type": "function",
+  "function": {
+    "name": "code_interpreter",
+    "description": "an ipython interpreter",
+    "parameters": {
+      "type": "object",
+      "properties": {
+        "code": {
+          "type": "string",
+          "description": "Python code to execute."
+        }
+      },
+      "required": ["code"]
+    }
+  }
+})");
 const json tools = {special_function_tool, python_tool};
+const json llama_3_1_tools = {special_function_tool, code_interpreter_tool};
 
 // static void test_parsing() {
 //     json request = {
@@ -427,6 +445,19 @@ static void test_grammars() {
            }},
        }}}
    };
+    auto code_interpreter_tool_call_message = json {
+        {"role", "assistant"},
+        {"content", {}},
+        {"tool_calls", json {{
+            {"type", "function"},
+            {"function", {
+                {"name", "code_interpreter"},
+                {"arguments", {
+                    {"code", "print('hey')"},
+                }},
+            }},
+        }}}
+    };
 
 
     common_chat_params no_tools_params;
@@ -494,10 +525,12 @@ static void test_grammars() {
        const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
        std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
 
-        assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params));
-        test_template(tmpl, end_tokens, text_message, tools);
+        // assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params));
+        // test_template(tmpl, end_tokens, text_message, tools);
+        test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools);
+        test_template(tmpl, end_tokens, python_tool_call_message, tools);
         test_template(tmpl, end_tokens, tool_call_message, tools);
-        test_template(tmpl, end_tokens, python_tool_call_message, tools);
+        test_template(tmpl, end_tokens, tool_call_message, llama_3_1_tools);
    }
    {
        const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
```
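For context on what the new test exercises: the grammar rule emitted by `handle_builtin_tool` for the `code_interpreter` tool constrains generation to `<|python_tag|>code_interpreter.call(code=<json string>)`. A sketch reproducing just that rule-string concatenation follows; the standalone `main` and the local `string_join` are illustrative, and `code-interpreter-args-code` stands in for whatever rule name `builder.add_schema` actually returns for the `code` property's schema.

```cpp
#include <iostream>
#include <string>
#include <vector>

// Local stand-in for the string_join helper used by the handler.
static std::string string_join(const std::vector<std::string> & values, const std::string & sep) {
    std::string result;
    for (size_t i = 0; i < values.size(); i++) {
        if (i > 0) result += sep;
        result += values[i];
    }
    return result;
}

int main() {
    // One "key=" prefix plus one value rule per property; code_interpreter has only "code".
    std::vector<std::string> kvs = {
        "\"code=\" code-interpreter-args-code",
    };

    // Same concatenation as in common_chat_init_llama_3_1_python_tag_tool_calls.
    std::string rule = "\"<|python_tag|>code_interpreter.call(\" " + string_join(kvs, " \", \" ") + " \")\"";

    std::cout << rule << "\n";
    // Prints the rule body: "<|python_tag|>code_interpreter.call(" "code=" code-interpreter-args-code ")"
}
```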
