Commit ba27e98

Author: ochafik

Unify llama 3.x chat handling again (allow {"type": "function", "name": ... prefix)

1 parent 7b5e080

File tree

5 files changed: +73 -102 lines

common/chat-handler.cpp

Lines changed: 39 additions & 72 deletions
@@ -344,7 +344,7 @@ static void expect_tool_parameters(const std::string & name, const json & parame
     }
 }
 
-static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params) {
+static common_chat_data common_chat_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params, bool allow_python_tag_builtin_tools) {
     auto builtin_tools = json::array();
     common_chat_data data;
     data.grammar_lazy = params.tool_choice != "required";

@@ -379,24 +379,31 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
             return true;
         };
 
+        auto has_function = false;
         foreach_function(params.tools, [&](const json & tool) {
            const auto & function = tool["function"];
            std::string name = function["name"];
            auto parameters = function["parameters"];
 
            // https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime
-           if (handle_builtin_tool(name, parameters)) {
+           if (allow_python_tag_builtin_tools && handle_builtin_tool(name, parameters)) {
                return;
            }
            builder.resolve_refs(parameters);
            tool_rules.push_back(
                builder.add_rule(
                    name + "-call",
-                   "\"{\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
+                   "\"{\" ( \"\\\"type\\\": \\\"function\\\", \" | space ) "
+                   "\"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
                    builder.add_schema(name + "-args", parameters) +
                    " \"}\""));
            data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true});
+           has_function = true;
        });
+       if (has_function) {
+           data.grammar_triggers.push_back({"{\"name\":", /* .at_start = */ true});
+           data.grammar_triggers.push_back({"{\"type\": \"function\"", /* .at_start = */ true});
+       }
        if (!builtin_tools.empty()) {
            data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
        }
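For a hypothetical tool named `get_weather`, the rule string assembled in the hunk above evaluates to a GBNF production whose first alternative carries the `"type": "function",` prefix. A standalone sketch that prints the expansion (not part of the commit; `get_weather-args` is a stand-in for whatever rule name `add_schema` would return):

```cpp
#include <iostream>
#include <string>

int main() {
    std::string name      = "get_weather";
    std::string args_rule = "get_weather-args";  // assumed return value of add_schema

    // Same string expression as in the diff above.
    std::string rule =
        "\"{\" ( \"\\\"type\\\": \\\"function\\\", \" | space ) "
        "\"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
        args_rule +
        " \"}\"";

    // Prints the GBNF body:
    // "{" ( "\"type\": \"function\", " | space ) "\"name\": \"get_weather\", \"parameters\": " get_weather-args "}"
    std::cout << rule << "\n";
}
```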
@@ -407,79 +414,44 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
         {"tools_in_user_message", false},
         {"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
     });
-    data.format = "llama 3.1 tool calls";
-    data.parser = [params](const std::string & input) -> common_chat_msg {
-        static std::regex function_regex("\\{\"name\": \"([^\"]+)\", \"parameters\": ");
+    data.format = std::string("llama 3.x tool calls") + (allow_python_tag_builtin_tools ? " (w/ builtin tools)" : "");
+    data.parser = [params, builtin_tools, allow_python_tag_builtin_tools](const std::string & input) -> common_chat_msg {
+        static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": ");
         static std::regex close_regex("\\}");
         static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\\.call\\((.*)\\)");
 
-        std::smatch match;
-        if (std::regex_match(input, match, builtin_call_regex)) {
-            auto name = match[1].str();
-            auto raw_args = match[2].str();
-
-            // TODO: if/when builtin tools start accepting more than 1 argument, use parse_json for real parsing.
-            auto it_eq = raw_args.find('=');
-            auto arg_name = raw_args.substr(0, it_eq);
-            auto arg_value_str = raw_args.substr(it_eq + 1);
-            auto arg_value = json::parse(arg_value_str);
-
-            return {
-                /* .role = */ "assistant",
-                /* .content = */ match.prefix().str(),
-                /* .tool_calls = */ {
-                    {
-                        /* .name = */ match[1],
-                        /* .arguments = */ (json {
-                            {arg_name, arg_value},
-                        }).dump(),
-                        /* .id = */ "",
-                    },
-                },
-            };
+        if (allow_python_tag_builtin_tools && !builtin_tools.empty()) {
+            std::smatch match;
+            if (std::regex_match(input, match, builtin_call_regex)) {
+                auto name = match[1].str();
+                auto raw_args = match[2].str();
+
+                // TODO: if/when builtin tools start accepting more than 1 argument, use parse_json for real parsing.
+                auto it_eq = raw_args.find('=');
+                auto arg_name = raw_args.substr(0, it_eq);
+                auto arg_value_str = raw_args.substr(it_eq + 1);
+                auto arg_value = json::parse(arg_value_str);
+
+                return {
+                    /* .role = */ "assistant",
+                    /* .content = */ match.prefix().str(),
+                    /* .tool_calls = */ {
+                        {
+                            /* .name = */ match[1],
+                            /* .arguments = */ (json {
+                                {arg_name, arg_value},
+                            }).dump(),
+                            /* .id = */ "",
+                        },
+                    },
+                };
+            }
         }
         return parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true);
     };
     return data;
 }
 
-static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params) {
-    common_chat_data data;
-
-    data.grammar_lazy = params.tool_choice != "required";
-    data.grammar = build_grammar([&](const common_grammar_builder & builder) {
-        std::vector<std::string> tool_rules;
-
-        foreach_function(params.tools, [&](const json & tool) {
-            const auto & function = tool["function"];
-            std::string name = function["name"];
-            auto parameters = function["parameters"];
-            builder.resolve_refs(parameters);
-            tool_rules.push_back(
-                builder.add_rule(
-                    name + "-call",
-                    "\"{\" "
-                    // " ( \"\\\"type\\\": \\\"function\\\", \" | space ) "
-                    "\"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
-                    builder.add_schema(name + "-args", parameters) +
-                    " \"}\""));
-            data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true});
-        });
-
-        builder.add_rule("root", string_join(tool_rules, " | "));
-    }, grammar_options);
-    data.additional_stops.push_back("<|eom_id|>");
-    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt, {});
-    data.format = "llama 3.2 tool calls";
-    data.parser = [params](const std::string & input) {
-        static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": ");
-        static std::regex close_regex("\\}");
-        auto res = parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true);
-        return res;
-    };
-    return data;
-}
-
 static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
     common_chat_data data;
     data.grammar_lazy = params.tool_choice != "required";
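The widened `function_regex` in the hunk above is what lets a single parser accept both payload shapes, which is the point of the commit title. A quick standalone check (not part of the commit; the two sample payloads are invented for the demo) shows the bare and the `"type": "function"`-prefixed forms matching the same pattern:

```cpp
#include <iostream>
#include <regex>
#include <string>

int main() {
    // Same pattern as the parser's function_regex in chat-handler.cpp.
    static const std::regex function_regex(
        "\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)"
        "\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": ");

    // Invented sample model outputs, one per accepted shape.
    const std::string bare     = "{\"name\": \"get_weather\", \"parameters\": {\"location\": \"Istanbul\"}}";
    const std::string prefixed = "{\"type\": \"function\", \"name\": \"get_weather\", \"parameters\": {\"location\": \"Istanbul\"}}";

    for (const auto & input : {bare, prefixed}) {
        std::smatch match;
        if (std::regex_search(input, match, function_regex)) {
            std::cout << "matched tool: " << match[1] << "\n";  // prints "get_weather" both times
        }
    }
}
```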
@@ -559,8 +531,8 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common
     // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
     common_chat_data data;
 
-    data.grammar_lazy = params.tool_choice != "required";
     if (!params.tools.is_null() && !params.tools.empty()) {
+        data.grammar_lazy = params.tool_choice != "required";
         data.grammar = build_grammar([&](const common_grammar_builder & builder) {
             std::vector<std::string> first_tool_rules;
             std::vector<std::string> subsequent_tool_rules;
@@ -806,13 +778,8 @@ common_chat_data common_chat_init(const common_chat_template & tmpl, const struc
         return common_chat_init_functionary_v3_1_llama_3_1_tool_call(tmpl, params);
     }
     if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
-        auto uses_python_tag = src.find("<|python_tag|>") != std::string::npos;
-
-        if (uses_python_tag) {
-            return common_chat_init_llama_3_1_python_tag_tool_calls(tmpl, params);
-        } else {
-            return common_chat_init_llama_3_2_tool_calls(tmpl, params);
-        }
+        auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
+        return common_chat_init_llama_3_1_tool_calls(tmpl, params, allow_python_tag_builtin_tools);
     }
     if (src.find("<|tool▁calls▁begin|>") != std::string::npos) {
         return common_chat_init_deepseek_r1_tool_call(tmpl, params);

examples/server/server.cpp

Lines changed: 5 additions & 2 deletions
@@ -3800,6 +3800,8 @@ int main(int argc, char ** argv) {
             /* .grammar = */ json_value(data, "grammar", std::string("")),
         });
         LOG_INF("Chat format: %s\n", chat_data.format.c_str());
+        LOG_DBG("Prompt: %s\n", chat_data.prompt.get<std::string>().c_str());
+        LOG_DBG("Grammar: %s\n", chat_data.grammar.c_str());
         if (data.contains("grammar")) {
             if (!chat_data.grammar.empty()) {
                 throw std::runtime_error("Cannot provide grammar and tools");
@@ -3841,11 +3843,11 @@
         for (const auto & trigger : chat_data.grammar_triggers) {
             auto ids = common_tokenize(ctx_server.vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
             if (ids.size() == 1) {
-                LOG_INF("Grammar trigger token: %s (%d)\n", trigger.word.c_str(), ids[0]);
+                LOG_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
                 task.params.sampling.grammar_trigger_tokens.push_back(ids[0]);
                 continue;
             }
-            LOG_INF("Grammar trigger word: %s\n", trigger.word.c_str());
+            LOG_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
             task.params.sampling.grammar_trigger_words.push_back(trigger);
         }
         task.params.antiprompt = chat_data.additional_stops;
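A side note on the trigger loop above: a trigger whose text tokenizes to exactly one vocab entry is promoted to a token-id trigger for the sampler, while multi-token triggers such as `{"name":` stay as string-matched words. A toy standalone illustration of that split (the `tokenize` stub and the token id below are invented for the sketch and bear no relation to common_tokenize's real behavior):

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Invented stub: pretend "<|python_tag|>" is one special token and
// everything else tokenizes byte-by-byte. Real tokenization differs.
static std::vector<int32_t> tokenize(const std::string & text) {
    if (text == "<|python_tag|>") {
        return {128010};  // hypothetical special-token id
    }
    std::vector<int32_t> ids;
    for (char c : text) {
        ids.push_back(static_cast<int32_t>(c));
    }
    return ids;
}

int main() {
    std::vector<int32_t>     trigger_tokens;  // fast path: matched by token id
    std::vector<std::string> trigger_words;   // fallback: matched by string

    for (const auto & word : {std::string("<|python_tag|>"), std::string("{\"name\":")}) {
        auto ids = tokenize(word);
        if (ids.size() == 1) {
            trigger_tokens.push_back(ids[0]);
            std::cout << "token trigger: " << word << " -> " << ids[0] << "\n";
        } else {
            trigger_words.push_back(word);
            std::cout << "word trigger:  " << word << "\n";
        }
    }
}
```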
@@ -4021,6 +4023,7 @@
     };
 
     const auto handle_chat_completions = [&ctx_server, &params, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
+        LOG_DBG("request: %s\n", req.body.c_str());
         if (ctx_server.params_base.embedding) {
             res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
             return;

examples/server/tests/unit/test_tool_call.py

Lines changed: 18 additions & 20 deletions
@@ -58,7 +58,7 @@ def create_server():
                 "required":["location"]
             }
         }
-}# TODO: fix this crash
+}
 
 
 def do_test_completion_with_required_tool_tiny(template_name: str, tool: dict, argument_key: str | None):
@@ -132,8 +132,8 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
 
 @pytest.mark.slow
 @pytest.mark.parametrize("tool,argument_key,hf_repo,hf_file,template_override", [
-    (TEST_TOOL, "success", "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
-    (PYTHON_TOOL, "code", "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
+    (TEST_TOOL, "success", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
+    (PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
     (TEST_TOOL, "success", "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
     (PYTHON_TOOL, "code", "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
     (TEST_TOOL, "success", "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
@@ -231,7 +231,7 @@ def test_completion_without_tool_call_fast(template_name: str, n_predict: int, t
 @pytest.mark.slow
 @pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
     # TODO: fix this crash
-    # ("meetkai-functionary-medium-v3.2", 256, [], None),
+    ("meetkai-functionary-medium-v3.2", 256, [], None),
     ("meetkai-functionary-medium-v3.2", 256, [TEST_TOOL], None),
     ("meetkai-functionary-medium-v3.2", 256, [PYTHON_TOOL], 'none'),
     ("meetkai-functionary-medium-v3.1", 256, [], None),
@@ -247,9 +247,7 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t
 
 @pytest.mark.slow
 @pytest.mark.parametrize("hf_repo,hf_file,template_override", [
-    # TODO: fix these
-    # ("lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
-    # ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF", "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf", None),
+    ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
     ("bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
     ("bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
     ("bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),

@@ -259,6 +257,8 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t
     ("bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai/functionary-medium-v3.2", None)),
     ("bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama/Llama-3.2-3B-Instruct", None)),
     ("bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama/Llama-3.2-3B-Instruct", None)),
+    # TODO: fix this (times out)
+    # ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF", "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf", None),
 ])
 def test_weather_tool_call(hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
     global server
@@ -276,7 +276,6 @@ def test_weather_tool_call(hf_repo: str, hf_file: str, template_override: Tuple[
     res = server.make_request("POST", "/chat/completions", data={
         "max_tokens": 256,
         "messages": [
-            # {"role": "system", "content": "Use tools as appropriate."},
             {"role": "user", "content": "What is the weather in Istanbul?"},
         ],
         "tools": [WEATHER_TOOL],
@@ -295,21 +294,21 @@ def test_weather_tool_call(hf_repo: str, hf_file: str, template_override: Tuple[
 
 
 @pytest.mark.slow
-@pytest.mark.parametrize("expected_arguments,hf_repo,hf_file,template_override", [
-    # TODO: fix these
-    # ('{"code":"print("}', "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
-    # (None, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
-    # (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF", "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf", None),
+@pytest.mark.parametrize("expected_arguments_override,hf_repo,hf_file,template_override", [
+    (None, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
+    (None, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
     (None, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
+    ('{"code":"print("}', "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
     (None, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
     ('{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
-    (None, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
-    (None, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
     (None, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
+    (None, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
     (None, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
     (None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q4_K_M.gguf", None),
+    # TODO: fix this (times out)
+    # (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF", "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf", None),
 ])
-def test_hello_world_tool_call(expected_arguments: str | None, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
+def test_hello_world_tool_call(expected_arguments_override: str | None, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
     global server
     server.n_slots = 1
     server.jinja = True

@@ -319,15 +318,14 @@ def test_hello_world_tool_call(expected_arguments: str | None, hf_repo: str, hf_
     server.model_hf_file = hf_file
     if template_override:
         (template_hf_repo, template_variant) = template_override
-        server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '') + ('-' + template_variant if template_variant else '')}.jinja"
+        server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
         assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     server.start(timeout_seconds=15*60)
     res = server.make_request("POST", "/chat/completions", data={
         "max_tokens": 256,
         "messages": [
             {"role": "system", "content": "You are a coding assistant."},
             {"role": "user", "content": "say hello world with python"},
-            # {"role": "user", "content": "Print a hello world message with python"},
         ],
         "tools": [PYTHON_TOOL],
         # Note: without these greedy params, Functionary v3.2 writes `def hello_world():\n    print("Hello, World!")\nhello_world()` which is correct but a pain to test.

@@ -342,8 +340,8 @@ def test_hello_world_tool_call(expected_arguments: str | None, hf_repo: str, hf_
     tool_call = tool_calls[0]
     assert tool_call["function"]["name"] == PYTHON_TOOL["function"]["name"]
     actual_arguments = tool_call["function"]["arguments"]
-    if expected_arguments is not None:
-        assert actual_arguments == expected_arguments
+    if expected_arguments_override is not None:
+        assert actual_arguments == expected_arguments_override
     else:
         actual_arguments = json.loads(actual_arguments)
         assert 'code' in actual_arguments, f"code not found in {json.dumps(actual_arguments)}"
