Skip to content

Commit 42b29e1

Browse files
author
ochafik
committed
fix double bos/eos jinja avoidance hack (was preventing inner bos/eos tokens)
1 parent 80c432b commit 42b29e1

File tree

2 files changed

+16
-9
lines changed

2 files changed

+16
-9
lines changed

common/chat.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -582,10 +582,17 @@ static std::string apply(
582582
// tmpl_inputs.now = std::chrono::system_clock::now();
583583

584584
minja::chat_template_options tmpl_opts;
585-
tmpl_opts.use_bos_token = false;
586-
tmpl_opts.use_eos_token = false;
587-
588-
return tmpl.apply(tmpl_inputs, tmpl_opts);
585+
// To avoid double BOS / EOS tokens, we're manually removing begining / trailing tokens
586+
// instead of using `chat_template_options.use_bos_token = false`, since these tokens
587+
// may be needed inside the template / between messages too.
588+
auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
589+
if (string_starts_with(result, tmpl.bos_token())) {
590+
result = result.substr(tmpl.bos_token().size());
591+
}
592+
if (string_ends_with(result, tmpl.eos_token())) {
593+
result = result.substr(0, result.size() - tmpl.eos_token().size());
594+
}
595+
return result;
589596
}
590597

591598
static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct templates_params & inputs) {

tests/test-chat-template.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ int main(void) {
5757
/* .template_str= */ "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
5858
/* .expected_output= */ "[INST] You are a helpful assistant\nHello [/INST]Hi there</s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
5959
/* .expected_output_jinja= */ "",
60-
/* .bos_token= */ "",
60+
/* .bos_token= */ "<s>",
6161
/* .eos_token= */ "</s>",
6262
},
6363
{
@@ -79,8 +79,8 @@ int main(void) {
7979
{
8080
/* .name= */ "mlabonne/AlphaMonarch-7B",
8181
/* .template_str= */ "{% for message in messages %}{{bos_token + message['role'] + '\\n' + message['content'] + eos_token + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ bos_token + 'assistant\\n' }}{% endif %}",
82-
/* .expected_output= */ "system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n I am an assistant </s>\n<s>user\nAnother question</s>\n<s>assistant\n",
83-
/* .expected_output_jinja= */ "<s>system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n I am an assistant </s>\n<s>user\nAnother question</s>\n<s>assistant\n",
82+
/* .expected_output= */ "system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n I am an assistant </s>\n<s>user\nAnother question</s>\n<s>assistant\n",
83+
/* .expected_output_jinja= */ "system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n I am an assistant </s>\n<s>user\nAnother question</s>\n<s>assistant\n",
8484
/* .bos_token= */ "<s>",
8585
/* .eos_token= */ "</s>",
8686
},
@@ -94,7 +94,7 @@ int main(void) {
9494
/* .name= */ "OrionStarAI/Orion-14B-Chat",
9595
/* .template_str= */ "{% for message in messages %}{% if loop.first %}{{ bos_token }}{% endif %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'] + '\\n\\nAssistant: ' + eos_token }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}{% endif %}{% endfor %}",
9696
/* .expected_output= */ "Human: You are a helpful assistant\n\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s> I am an assistant </s>Human: Another question\n\nAssistant: </s>",
97-
/* .expected_output_jinja= */ "Human: You are a helpful assistant\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s> I am an assistant </s>Human: Another question\n\nAssistant: </s>",
97+
/* .expected_output_jinja= */ "Human: You are a helpful assistant\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s> I am an assistant </s>Human: Another question\n\nAssistant: ",
9898
/* .bos_token= */ "",
9999
/* .eos_token= */ "</s>",
100100
},
@@ -323,7 +323,7 @@ int main(void) {
323323
try {
324324
common_chat_templates_ptr tmpls(common_chat_templates_init(/* model= */ nullptr, test_case.template_str.c_str(), test_case.bos_token, test_case.eos_token), &common_chat_templates_free);
325325
common_chat_templates_inputs inputs;
326-
inputs.use_jinja = false;
326+
inputs.use_jinja = true;
327327
inputs.messages = messages;
328328
inputs.add_generation_prompt = add_generation_prompt;
329329
auto output = common_chat_templates_apply(tmpls.get(), inputs).prompt;

0 commit comments

Comments
 (0)