@@ -377,6 +377,59 @@ TEST_CASE("custom template") {
377377 CHECK (res == expected_str);
378378}
379379
380+ TEST_CASE (" generation prompt" ) {
381+ const std::vector<ac::llama::ChatMsg> chat = {
382+ {" system" , " You are a helpful assistant" },
383+ {" user" , " Hello" },
384+ {" assistant" , " Hello, how can I help?" },
385+ {" user" , " I need help with my homework" },
386+ };
387+
388+ SUBCASE (" llama.cpp template" ) {
389+ ac::llama::ChatFormat fmt (" llama3" );
390+
391+ std::string expectedWithoutGenPrompt =
392+ " <|start_header_id|>system<|end_header_id|>\n\n "
393+ " You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n "
394+ " Hello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n "
395+ " Hello, how can I help?<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n "
396+ " I need help with my homework<|eot_id|>" ;
397+
398+ std::string expectedWithGenPrompt = expectedWithoutGenPrompt + " <|start_header_id|>assistant<|end_header_id|>\n\n " ;
399+
400+ CHECK (fmt.formatChat (chat, true ) == expectedWithGenPrompt);
401+ CHECK (fmt.formatChat (chat, false ) == expectedWithoutGenPrompt);
402+ }
403+
404+ SUBCASE (" custom template" ) {
405+ const std::string chatTemplate =
406+ " {% for message in messages %}"
407+ " {{ '<|' + message['role'] + '|>\\ n' + message['content'] + '<|end|>' + '\\ n' }}"
408+ " {% endfor %}"
409+ " {% if add_generation_prompt %}"
410+ " {{ '<|' + assistant_role + '|>\\ n' }}"
411+ " {% endif %}" ;
412+
413+ ac::llama::ChatFormat fmt{{
414+ .chatTemplate = chatTemplate,
415+ .bosToken = " " ,
416+ .eosToken = " " ,
417+ .roleAssistant = " assistant"
418+ }};
419+
420+ std::string expectedWithoutGenPrompt =
421+ " <|system|>\n You are a helpful assistant<|end|>\n "
422+ " <|user|>\n Hello<|end|>\n "
423+ " <|assistant|>\n Hello, how can I help?<|end|>\n "
424+ " <|user|>\n I need help with my homework<|end|>\n " ;
425+ std::string expectedWithGenPrompt = expectedWithoutGenPrompt + " <|assistant|>\n " ;
426+
427+ CHECK (fmt.formatChat (chat, true ) == expectedWithGenPrompt);
428+ CHECK (fmt.formatChat (chat, false ) == expectedWithoutGenPrompt);
429+ }
430+
431+ }
432+
380433TEST_CASE (" invalid custom template" ) {
381434 std::string bad_template = R"(
382435{% for message in messages %}
0 commit comments