Skip to content

Commit a270b2d

Browse files
committed
chore: add test for the generation prompt, ref #65
1 parent 2c8d4cf commit a270b2d

File tree

1 file changed

+53
-0
lines changed

1 file changed

+53
-0
lines changed

test/t-ChatFormat.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,59 @@ TEST_CASE("custom template") {
377377
CHECK(res == expected_str);
378378
}
379379

380+
TEST_CASE("generation prompt") {
381+
const std::vector<ac::llama::ChatMsg> chat = {
382+
{"system", "You are a helpful assistant"},
383+
{"user", "Hello"},
384+
{"assistant", "Hello, how can I help?"},
385+
{"user", "I need help with my homework"},
386+
};
387+
388+
SUBCASE("llama.cpp template") {
389+
ac::llama::ChatFormat fmt("llama3");
390+
391+
std::string expectedWithoutGenPrompt =
392+
"<|start_header_id|>system<|end_header_id|>\n\n"
393+
"You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
394+
"Hello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
395+
"Hello, how can I help?<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
396+
"I need help with my homework<|eot_id|>";
397+
398+
std::string expectedWithGenPrompt = expectedWithoutGenPrompt + "<|start_header_id|>assistant<|end_header_id|>\n\n";
399+
400+
CHECK(fmt.formatChat(chat, true) == expectedWithGenPrompt);
401+
CHECK(fmt.formatChat(chat, false) == expectedWithoutGenPrompt);
402+
}
403+
404+
SUBCASE("custom template") {
405+
const std::string chatTemplate =
406+
"{% for message in messages %}"
407+
"{{ '<|' + message['role'] + '|>\\n' + message['content'] + '<|end|>' + '\\n' }}"
408+
"{% endfor %}"
409+
"{% if add_generation_prompt %}"
410+
"{{ '<|' + assistant_role + '|>\\n' }}"
411+
"{% endif %}";
412+
413+
ac::llama::ChatFormat fmt{{
414+
.chatTemplate = chatTemplate,
415+
.bosToken = "",
416+
.eosToken = "",
417+
.roleAssistant = "assistant"
418+
}};
419+
420+
std::string expectedWithoutGenPrompt =
421+
"<|system|>\nYou are a helpful assistant<|end|>\n"
422+
"<|user|>\nHello<|end|>\n"
423+
"<|assistant|>\nHello, how can I help?<|end|>\n"
424+
"<|user|>\nI need help with my homework<|end|>\n";
425+
std::string expectedWithGenPrompt = expectedWithoutGenPrompt + "<|assistant|>\n";
426+
427+
CHECK(fmt.formatChat(chat, true) == expectedWithGenPrompt);
428+
CHECK(fmt.formatChat(chat, false) == expectedWithoutGenPrompt);
429+
}
430+
431+
}
432+
380433
TEST_CASE("invalid custom template") {
381434
std::string bad_template = R"(
382435
{% for message in messages %}

0 commit comments

Comments
 (0)