Skip to content

Commit 4a1e8e9

Browse files
author
ochafik
committed
refactor test-chat-handler
1 parent 18d5a1b commit 4a1e8e9

File tree

1 file changed

+18
-64
lines changed

1 file changed

+18
-64
lines changed

tests/test-chat-handler.cpp

Lines changed: 18 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -319,32 +319,10 @@ static void test_template_output_parsers() {
319319
const common_chat_template tmpl(read_file("tests/chat/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
320320
std::vector<std::string> end_tokens { "<end_of_turn>" };
321321

322+
assert_equals(std::string("content-only"), describe(tmpl, no_tools_params));
322323
assert_equals(std::string("generic tool calls"), describe(tmpl, tools_params));
323-
assert_equals(std::string("content-only"), describe(tmpl, no_tools_params));
324-
// Generic tool calls doesn't generate / parse content-only messages symmetrically.
325-
assert_msg_equals(msg_from_json(text_message), common_chat_init(tmpl, tools_params).parser(
326-
"{\n"
327-
" \"response\": \"Hello, world!\"\n"
328-
"}"));
329-
test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
330-
"{\n"
331-
" \"tool_calls\": [\n"
332-
" {\n"
333-
" \"name\": \"special_function\",\n"
334-
" \"arguments\": {\n"
335-
" \"arg1\": 1\n"
336-
" },\n"
337-
" \"id\": \"123456789\"\n"
338-
" }\n"
339-
" ]\n"
340-
"}");
341-
}
342-
{
343-
const common_chat_template tmpl(read_file("tests/chat/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>");
344-
std::vector<std::string> end_tokens { "<|end|>" };
324+
assert_equals(std::string("generic tool calls"), describe(common_chat_template(read_file("tests/chat/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>"), tools_params));
345325

346-
assert_equals(std::string("generic tool calls"), describe(tmpl, tools_params));
347-
assert_equals(std::string("content-only"), describe(tmpl, no_tools_params));
348326
// Generic tool calls doesn't generate / parse content-only messages symmetrically.
349327
assert_msg_equals(msg_from_json(text_message), common_chat_init(tmpl, tools_params).parser(
350328
"{\n"
@@ -368,16 +346,20 @@ static void test_template_output_parsers() {
368346
std::vector<std::string> end_tokens { "</s>" };
369347

370348
assert_equals(std::string("mistral nemo tool calls"), describe(tmpl, tools_params));
349+
371350
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
372351
test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
373352
"[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]",
374353
/* skip_grammar_test= */ true);
375354
}
376355
{
377-
const common_chat_template tmpl(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>");
356+
const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
378357
std::vector<std::string> end_tokens { "<|im_end|>" };
379358

380359
assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
360+
assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>"), tools_params));
361+
assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"), tools_params));
362+
381363
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
382364
test_template(tmpl, end_tokens, tool_call_message, tools,
383365
"<tool_call>\n"
@@ -388,80 +370,50 @@ static void test_template_output_parsers() {
388370
"{\"name\": \"python\", \"arguments\": {\"code\": \"print('hey')\"}}\n"
389371
"</tool_call>");
390372
}
391-
{
392-
const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
393-
std::vector<std::string> end_tokens { "<|im_end|>" };
394-
395-
assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
396-
test_template(tmpl, end_tokens, text_message, tools,
397-
"Hello, world!", /* skip_grammar_test= */ true);
398-
test_template(tmpl, end_tokens, tool_call_message, tools,
399-
"<tool_call>\n"
400-
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
401-
"</tool_call>");
402-
}
403-
{
404-
const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>");
405-
std::vector<std::string> end_tokens { "<|im_end|>" };
406-
407-
assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
408-
test_template(tmpl, end_tokens, text_message, tools,
409-
"Hello, world!", /* skip_grammar_test= */ true);
410-
test_template(tmpl, end_tokens, tool_call_message, tools,
411-
"<tool_call>\n"
412-
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
413-
"</tool_call>");
414-
}
415373
{
416374
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
417375
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
418376

419377
assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params));
378+
assert_equals(std::string("llama 3.1 tool calls"), describe(common_chat_template(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>"), tools_params));
379+
420380
// test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true);
421381
test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
422382
"<|python_tag|>code_interpreter.call(code=\"print('hey')\")");
423383
test_template(tmpl, end_tokens, python_tool_call_message, tools,
424384
"<|python_tag|>python.call(code=\"print('hey')\")");
425385
test_template(tmpl, end_tokens, tool_call_message, tools,
426386
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
427-
test_template(tmpl, end_tokens, tool_call_message, llama_3_1_tools);
428387
}
429388
{
430-
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>");
389+
const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
431390
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
432391

433-
assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params));
392+
assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), describe(tmpl, tools_params));
393+
434394
test_template(tmpl, end_tokens, text_message, tools,
435395
"Hello, world!", /* skip_grammar_test= */ true);
436396
test_template(tmpl, end_tokens, tool_call_message, tools,
437-
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
397+
"<function=special_function>{\"arg1\": 1}</function>");
438398
}
439399
{
440400
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
441401
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
442402

443403
assert_equals(std::string("llama 3.2 tool calls"), describe(tmpl, tools_params));
444-
test_template(tmpl, end_tokens, text_message, tools,
445-
"Hello, world!", /* skip_grammar_test= */ true);
446-
test_template(tmpl, end_tokens, tool_call_message, tools,
447-
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
448-
}
449-
{
450-
const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
451-
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
452404

453-
assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), describe(tmpl, tools_params));
454405
test_template(tmpl, end_tokens, text_message, tools,
455406
"Hello, world!", /* skip_grammar_test= */ true);
456407
test_template(tmpl, end_tokens, tool_call_message, tools,
457-
"<function=special_function>{\"arg1\": 1}</function>");
408+
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
458409
}
459410
{
460411
const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>", "</s>");
461412
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
462413

463414
assert_equals(std::string("functionary v3.2 content-only"), describe(tmpl, no_tools_params));
464-
assert_equals(std::string("functionary v3.2 tool calls"), describe(tmpl, tools_params));
415+
assert_equals(std::string("functionary v3.2 tool calls"), describe(tmpl, tools_params));
416+
465417
test_template(tmpl, end_tokens, text_message, tools,
466418
"all\n"
467419
"Hello, world!", /* skip_grammar_test= */ true);
@@ -474,6 +426,7 @@ static void test_template_output_parsers() {
474426
std::vector<std::string> end_tokens { "<|eot_id|>" };
475427

476428
assert_equals(std::string("firefunction v2 tool calls"), describe(tmpl, tools_params));
429+
477430
test_template(tmpl, end_tokens, text_message, tools,
478431
"Hello, world!", /* skip_grammar_test= */ true);
479432
test_template(tmpl, end_tokens, tool_call_message, tools,
@@ -484,6 +437,7 @@ static void test_template_output_parsers() {
484437
std::vector<std::string> end_tokens { "<|end▁of▁sentence|>" };
485438

486439
assert_equals(std::string("deepseek r1 tool calls"), describe(tmpl, tools_params));
440+
487441
test_template(tmpl, end_tokens, text_message, tools,
488442
"Hello, world!", /* skip_grammar_test= */ true);
489443
test_template(tmpl, end_tokens, tool_call_message, tools,

0 commit comments

Comments
 (0)