Skip to content

Commit 2b24569

Browse files
author
ochafik
committed
Add CLI mode to test-chat to generate template summaries in Markdown
1 parent 84bc083 commit 2b24569

File tree

1 file changed

+72
-36
lines changed

1 file changed

+72
-36
lines changed

tests/test-chat.cpp

Lines changed: 72 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,13 @@
1+
/*
2+
Tests chat handling, including grammar generation and parsing for tool calling, for various templates.
3+
4+
Also acts as a CLI to generate a Markdown summary of the formats of Jinja templates,
5+
e.g. given Minja (https://github.com/google/minja) checked out in parent dir:
6+
7+
cmake -B build && cmake --build build --parallel && \
8+
./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null
9+
10+
*/
111
#include "chat.hpp"
212
#include "chat-template.hpp"
313
#include "llama-grammar.h"
@@ -44,7 +54,7 @@ static void assert_equals(const T & expected, const T & actual) {
4454
}
4555

4656
static std::string read_file(const std::string &path) {
47-
std::cout << "# Reading: " << path << std::endl << std::flush;
57+
std::cerr << "# Reading: " << path << std::endl << std::flush;
4858
std::ifstream fs(path, std::ios_base::binary);
4959
if (!fs.is_open()) {
5060
fs = std::ifstream("../" + path, std::ios_base::binary);
@@ -168,13 +178,15 @@ struct delta_data {
168178
common_chat_parser parser;
169179
};
170180

171-
static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
181+
static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools, const json & tool_choice) {
172182
common_chat_inputs inputs;
173183
inputs.parallel_tool_calls = true;
174184
inputs.messages = json::array();
175185
inputs.messages.push_back(user_message);
176186
inputs.tools = tools;
187+
inputs.tool_choice = tool_choice;
177188
auto params_prefix = common_chat_params_init(tmpl, inputs);
189+
178190
inputs.messages.push_back(delta_message);
179191
inputs.add_generation_prompt = false;
180192
auto params_full = common_chat_params_init(tmpl, inputs);
@@ -220,7 +232,7 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
220232
};
221233

222234
for (const auto & tool_choice : json({"auto", "required"})) {
223-
auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools);
235+
auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools, tool_choice);
224236
if (!expected_delta.empty()) {
225237
assert_equals(expected_delta, data.delta);
226238
}
@@ -248,6 +260,10 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
248260
}
249261
}
250262

263+
static std::string describe(const common_chat_template & tmpl, const common_chat_inputs & params) {
264+
return common_chat_params_init(tmpl, params).format;
265+
}
266+
251267
static void test_template_output_parsers() {
252268
auto text_message = json {
253269
{"role", "assistant"},
@@ -295,29 +311,25 @@ static void test_template_output_parsers() {
295311
};
296312

297313

298-
common_chat_inputs no_tools_params;
299-
no_tools_params.messages = {{{"role", "user"}, {"content", "Hey"}}};
314+
common_chat_inputs inputs_no_tools;
315+
inputs_no_tools.messages = {{{"role", "user"}, {"content", "Hey"}}};
300316

301-
common_chat_inputs tools_params = no_tools_params;
302-
tools_params.tools = json::array();
303-
tools_params.tools.push_back(special_function_tool);
304-
305-
auto describe = [](const common_chat_template & tmpl, const common_chat_inputs & params) {
306-
return common_chat_params_init(tmpl, params).format;
307-
};
317+
common_chat_inputs inputs_tools = inputs_no_tools;
318+
inputs_tools.tools = json::array();
319+
inputs_tools.tools.push_back(special_function_tool);
308320

309321
{
310322
const common_chat_template tmpl(read_file(
311323
"models/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
312324
std::vector<std::string> end_tokens { "<end_of_turn>" };
313325

314-
assert_equals(std::string("content-only"), describe(tmpl, no_tools_params));
315-
assert_equals(std::string("generic tool calls"), describe(tmpl, tools_params));
316-
assert_equals(std::string("generic tool calls"), describe(common_chat_template(read_file(
317-
"models/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>"), tools_params));
326+
assert_equals(std::string("content-only"), common_chat_params_init(tmpl, inputs_no_tools).format);
327+
assert_equals(std::string("generic tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
328+
assert_equals(std::string("generic tool calls"), common_chat_params_init(common_chat_template(read_file(
329+
"models/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>"), inputs_tools).format);
318330

319331
// Generic tool calls doesn't generate / parse content-only messages symmetrically.
320-
assert_msg_equals(msg_from_json(text_message), common_chat_params_init(tmpl, tools_params).parser(
332+
assert_msg_equals(msg_from_json(text_message), common_chat_params_init(tmpl, inputs_tools).parser(
321333
"{\n"
322334
" \"response\": \"Hello, world!\"\n"
323335
"}"));
@@ -339,7 +351,7 @@ static void test_template_output_parsers() {
339351
"models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>", "</s>");
340352
std::vector<std::string> end_tokens { "</s>" };
341353

342-
assert_equals(std::string("mistral nemo tool calls"), describe(tmpl, tools_params));
354+
assert_equals(std::string("mistral nemo tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
343355

344356
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
345357
test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
@@ -351,11 +363,11 @@ static void test_template_output_parsers() {
351363
"models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
352364
std::vector<std::string> end_tokens { "<|im_end|>" };
353365

354-
assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
355-
assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file(
356-
"models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>"), tools_params));
357-
assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file(
358-
"models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"), tools_params));
366+
assert_equals(std::string("hermes 2 pro tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
367+
assert_equals(std::string("hermes 2 pro tool calls"), common_chat_params_init(common_chat_template(read_file(
368+
"models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>"), inputs_tools).format);
369+
assert_equals(std::string("hermes 2 pro tool calls"), common_chat_params_init(common_chat_template(read_file(
370+
"models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"), inputs_tools).format);
359371

360372
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
361373
test_template(tmpl, end_tokens, tool_call_message, tools,
@@ -372,9 +384,9 @@ static void test_template_output_parsers() {
372384
"models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
373385
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
374386

375-
assert_equals(std::string("llama 3.x tool calls (w/ builtin tools)"), describe(tmpl, tools_params));
376-
assert_equals(std::string("llama 3.x tool calls (w/ builtin tools)"), describe(common_chat_template(read_file(
377-
"models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>"), tools_params));
387+
assert_equals(std::string("llama 3.x tool calls (w/ builtin tools)"), common_chat_params_init(tmpl, inputs_tools).format);
388+
assert_equals(std::string("llama 3.x tool calls (w/ builtin tools)"), common_chat_params_init(common_chat_template(read_file(
389+
"models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>"), inputs_tools).format);
378390

379391
// test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true);
380392
test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
@@ -389,7 +401,7 @@ static void test_template_output_parsers() {
389401
"models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
390402
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
391403

392-
assert_equals(std::string("llama 3.x tool calls"), describe(tmpl, tools_params));
404+
assert_equals(std::string("llama 3.x tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
393405

394406
test_template(tmpl, end_tokens, text_message, tools,
395407
"Hello, world!", /* skip_grammar_test= */ true);
@@ -401,7 +413,7 @@ static void test_template_output_parsers() {
401413
"models/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
402414
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
403415

404-
assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), describe(tmpl, tools_params));
416+
assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
405417

406418
test_template(tmpl, end_tokens, text_message, tools,
407419
"Hello, world!", /* skip_grammar_test= */ true);
@@ -413,8 +425,8 @@ static void test_template_output_parsers() {
413425
"models/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>", "</s>");
414426
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
415427

416-
assert_equals(std::string("functionary v3.2 content-only"), describe(tmpl, no_tools_params));
417-
assert_equals(std::string("functionary v3.2 tool calls"), describe(tmpl, tools_params));
428+
assert_equals(std::string("functionary v3.2 content-only"), common_chat_params_init(tmpl, inputs_no_tools).format);
429+
assert_equals(std::string("functionary v3.2 tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
418430

419431
test_template(tmpl, end_tokens, text_message, tools,
420432
"all\n"
@@ -428,7 +440,7 @@ static void test_template_output_parsers() {
428440
"models/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>", "</s>");
429441
std::vector<std::string> end_tokens { "<|eot_id|>" };
430442

431-
assert_equals(std::string("firefunction v2 tool calls"), describe(tmpl, tools_params));
443+
assert_equals(std::string("firefunction v2 tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
432444

433445
test_template(tmpl, end_tokens, text_message, tools,
434446
"Hello, world!", /* skip_grammar_test= */ true);
@@ -440,7 +452,7 @@ static void test_template_output_parsers() {
440452
"models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), "<s>", "</s>");
441453
std::vector<std::string> end_tokens { "<|end▁of▁sentence|>" };
442454

443-
assert_equals(std::string("deepseek r1 tool calls"), describe(tmpl, tools_params));
455+
assert_equals(std::string("deepseek r1 tool calls"), common_chat_params_init(tmpl, inputs_tools).format);
444456

445457
test_template(tmpl, end_tokens, text_message, tools,
446458
"Hello, world!", /* skip_grammar_test= */ true);
@@ -452,9 +464,33 @@ static void test_template_output_parsers() {
452464
}
453465
}
454466

455-
int main() {
456-
test_template_output_parsers();
457-
458-
std::cout << "\n[tool-call] All tests passed!" << std::endl;
467+
int main(int argc, char **argv) {
468+
#ifndef _WIN32
469+
if (argc > 1) {
470+
common_chat_inputs inputs;
471+
inputs.messages = {{{"role", "user"}, {"content", "Hey"}}};
472+
inputs.tools = json::array({special_function_tool});
473+
474+
std::cout << "| Template | Format |\n";
475+
std::cout << "|----------|--------|\n";
476+
477+
for (int i = 1; i < argc; i++) {
478+
std::string path = argv[i];
479+
if (path.rfind(".jinja") != path.size() - 6) {
480+
std::cerr << "Skipping non-jinja file: " << path << std::endl;
481+
continue;
482+
}
483+
common_chat_template tmpl(read_file(path), "", "");
484+
auto parts = string_split(path, "/");
485+
auto name = parts[parts.size() - 1];
486+
std::cout << "| " << name << " | " << common_chat_params_init(tmpl, inputs).format << " |\n";
487+
}
488+
}
489+
else
490+
#endif
491+
{
492+
test_template_output_parsers();
493+
std::cout << "\n[chat] All tests passed!" << std::endl;
494+
}
459495
return 0;
460496
}

0 commit comments

Comments (0)