Skip to content

Commit add9124

Browse files
author
ochafik
committed
fix test-chat-handler grammar tests
1 parent 118f799 commit add9124

File tree

3 files changed

+82
-59
lines changed

3 files changed

+82
-59
lines changed

common/chat-handler.cpp

Lines changed: 27 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -363,6 +363,7 @@ static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_cha
363363

364364
static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params) {
365365
fprintf(stderr, "[%s]\n", __func__);
366+
// TODO: get from request body.
366367
auto builtin_tools = json {"wolfram_alpha", "brave_search"};
367368
common_chat_data data;
368369

@@ -377,10 +378,16 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
377378
tool_rules.push_back(
378379
builder.add_rule(
379380
name + "-call",
380-
"\"<|python_tag|>\" \"{\" ( \"\\\"type\\\": \\\"function\\\", \" | space ) \"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
381+
"\"{\" "
382+
// " ( \"\\\"type\\\": \\\"function\\\", \" | space ) "
383+
"\"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
381384
builder.add_schema(name + "-args", parameters) +
382385
" \"}\""));
386+
if (params.tool_choice != "required") {
387+
data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true});
388+
}
383389
});
390+
tool_rules.push_back(builder.add_rule("builtin-tool-call", "\"<|python_tag|>\" .*"));
384391
if (params.tool_choice != "required") {
385392
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
386393
}
@@ -391,11 +398,27 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
391398
{"builtin_tools", builtin_tools},
392399
});
393400
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
394-
static std::regex function_regex("<\\|python_tag\\|>\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": ");
401+
static std::regex function_regex("\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": ");
395402
static std::regex close_regex("\\}");
403+
// Matches a built-in tool invocation of the form `<|python_tag|>name(args)`.
//   group 1: the tool name (one or more characters other than '.' and '(')
//   group 2: the raw argument text between the outermost parentheses
// The literal parentheses must be written `\\(` / `\\)` in a non-raw string literal: a bare `\(` is not a
// valid C++ escape sequence, so the compiler would hand std::regex an unescaped grouping paren instead.
static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\\((.*)\\)");
404+
405+
std::smatch match;
406+
if (std::regex_match(input, match, builtin_call_regex)) {
407+
auto arguments = json::parse("[" + match[2].str() + "]");
408+
return {
409+
/* .role = */ "assistant",
410+
/* .content = */ match.prefix().str(),
411+
/* .tool_calls = */ {
412+
{
413+
/* .name = */ match[1],
414+
/* .arguments = */ arguments.dump(),
415+
/* .id = */ "",
416+
},
417+
},
418+
};
419+
}
396420
return parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true);
397421
});
398-
fprintf(stderr, "Grammar: %s\n", data.grammar.c_str());
399422
return data;
400423
}
401424

@@ -435,7 +458,6 @@ static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_
435458
auto res = parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true);
436459
return res;
437460
});
438-
fprintf(stderr, "Grammar: %s\n", data.grammar.c_str());
439461
return data;
440462
}
441463

@@ -590,9 +612,8 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(cons
590612
} else if (type != "string") {
591613
throw std::runtime_error("Invalid type in python tool: " + type.dump());
592614
}
593-
} else {
594-
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
595615
}
616+
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
596617
});
597618
if (has_raw_python) {
598619
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));

common/chat-template.hpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ class chat_template {
129129
bool supports_tools() const { return supports_tools_; }
130130
bool supports_tool_calls() const { return supports_tool_calls_; }
131131
bool supports_parallel_tool_calls() const { return supports_parallel_tool_calls_; }
132+
bool requires_object_arguments() const { return requires_object_arguments_; }
132133

133134
std::string apply(
134135
const nlohmann::ordered_json & messages,
@@ -201,12 +202,14 @@ class chat_template {
201202
for (auto & tool_call : message.at("tool_calls")) {
202203
if (tool_call["type"] == "function") {
203204
auto & function = tool_call.at("function");
204-
std::string arguments = function.at("arguments");
205-
try {
206-
function["arguments"] = json::parse(arguments);
207-
} catch (const std::exception & ecvt) {
208-
fprintf(stderr, "Failed to parse arguments: %s\n", ecvt.what());
209-
function["arguments"] = arguments;
205+
auto & arguments = function.at("arguments");
206+
if (arguments.is_string()) {
207+
try {
208+
arguments = json::parse(arguments.get<std::string>());
209+
} catch (const std::exception & ecvt) {
210+
fprintf(stderr, "Failed to parse arguments: %s\n", ecvt.what());
211+
arguments = arguments;
212+
}
210213
}
211214
}
212215
}

tests/test-chat-handler.cpp

Lines changed: 46 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -72,32 +72,17 @@ static std::string dump(const json & j) {
7272
return minja::Value(j).dump(-1, /* to_json= */ true);
7373
}
7474

75-
static void assert_msg_equals(const common_chat_msg & result, const std::string & expected_content, const json & expected_tool_calls) {
76-
assert_equals(expected_content, result.content);
77-
auto tool_calls = json::array();
78-
for (const auto & tc : result.tool_calls) {
79-
auto arguments = tc.arguments;
80-
try {
81-
arguments = dump(json::parse(arguments));
82-
} catch (const std::exception & e) {
83-
// ignore
84-
}
85-
auto tool_call = json {
86-
{"type", "function"},
87-
{"function", {
88-
{"arguments", arguments},
89-
{"name", tc.name},
90-
}},
91-
};
92-
if (!tc.id.empty()) {
93-
tool_call["id"] = tc.id;
94-
}
95-
tool_calls.push_back(tool_call);
75+
static void assert_msg_equals(const common_chat_msg & expected, const common_chat_msg & actual) {
76+
assert_equals(expected.role, actual.role);
77+
assert_equals(expected.content, actual.content);
78+
assert_equals(expected.tool_calls.size(), actual.tool_calls.size());
79+
for (size_t i = 0; i < expected.tool_calls.size(); i++) {
80+
const auto & expected_tool_call = expected.tool_calls[i];
81+
const auto & actual_tool_call = actual.tool_calls[i];
82+
assert_equals(expected_tool_call.name, actual_tool_call.name);
83+
assert_equals(dump(json::parse(expected_tool_call.arguments)), dump(json::parse(actual_tool_call.arguments)));
84+
assert_equals(expected_tool_call.id, actual_tool_call.id);
9685
}
97-
// Reparse / dump w/ non-ordered JSON variant.
98-
auto expected = nlohmann::json::parse(expected_tool_calls.dump()).dump();
99-
auto actual = nlohmann::json::parse(tool_calls.dump()).dump();
100-
assert_equals(expected, actual);
10186
}
10287

10388
const auto special_function_tool = json::parse(R"({
@@ -373,7 +358,19 @@ static std::string get_message_prompt_delta(const common_chat_template & tmpl, c
373358

374359
static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & tool_calling_message, const json & tools, bool skip_grammar_test = false) {
375360
// auto tool_call_style = common_tool_call_style_detect(tmpl);
376-
auto & tool_calls = tool_calling_message.at("tool_calls");
361+
common_chat_msg expected_msg {
362+
"assistant",
363+
"",
364+
{},
365+
};
366+
for (const auto & tc : tool_calling_message.at("tool_calls")) {
367+
const auto & arguments = tc.at("function").at("arguments");
368+
expected_msg.tool_calls.push_back({
369+
tc.at("function").at("name").get<std::string>(),
370+
arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
371+
tc.contains("id") ? tc.at("id").get<std::string>() : "",
372+
});
373+
}
377374

378375
// Format the message: apply the template to 1 user message w/ add_generation_prompt=true, then w/ the extra message w/ add_generation_prompt=false,
379376
// get the diff and try and parse it w/ the grammar.
@@ -398,12 +395,12 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
398395
std::cout << "Full delta:\n```\n" << full_delta << "\n```" << std::endl;
399396

400397
const auto msg = chat_data.parser->parse_final(full_delta);
401-
assert_msg_equals(msg, "", tool_calls);
398+
assert_msg_equals(expected_msg, msg);
402399

403400
auto content_less_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, {
404401
{"role", "assistant"},
405402
{"content", {}},
406-
{"tool_calls", tool_calls}
403+
{"tool_calls", tool_calling_message.at("tool_calls")}
407404
}, tools);
408405
if (!match_string(content_less_delta, grammar.get())) {
409406
throw std::runtime_error("Failed to match content-less delta against grammar:\n\nContent-less delta: " + content_less_delta + "\n\nGrammar: " + chat_data.grammar);
@@ -433,7 +430,9 @@ static void test_grammars() {
433430
{"type", "function"},
434431
{"function", {
435432
{"name", "python"},
436-
{"arguments", "print('hey')"}
433+
{"arguments", {
434+
{"code", "print('hey')"},
435+
}},
437436
}},
438437
}}}
439438
};
@@ -442,12 +441,12 @@ static void test_grammars() {
442441
const common_chat_template tmpl(read_file("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>", "</s>");
443442
test_template(tmpl, { "</s>" }, tool_call_message_with_id, tools, /* skip_grammar_test= */ true);
444443
}
445-
// {
446-
// const common_chat_template tmpl(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>");
447-
// // assert_equals(tmpl.requires_object_arguments_, true);
448-
// test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools);
449-
// test_template(tmpl, { "<|im_end|>" }, python_tool_call_message, tools);
450-
// }
444+
{
445+
const common_chat_template tmpl(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>");
446+
// assert_equals(tmpl.requires_object_arguments_, true);
447+
test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools);
448+
test_template(tmpl, { "<|im_end|>" }, python_tool_call_message, tools);
449+
}
451450
{
452451
const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
453452
test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools);
@@ -456,22 +455,22 @@ static void test_grammars() {
456455
const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>");
457456
test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools);
458457
}
459-
// {
460-
// const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
461-
// test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
462-
// }
463-
// {
464-
// const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
465-
// test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, python_tool_call_message, tools);
466-
// }
458+
{
459+
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
460+
test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
461+
}
462+
{
463+
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
464+
test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, python_tool_call_message, tools);
465+
}
467466
{
468467
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
469468
test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
470469
}
471-
// {
472-
// const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>");
473-
// test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
474-
// }
470+
{
471+
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>");
472+
test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
473+
}
475474
{
476475
const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
477476
test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);

0 commit comments

Comments
 (0)