Skip to content

Commit f324a3b

Browse files
authored
chat : only remove double bos/eos if added (#15086)
* only remove double bos/eos if added * fix tests
1 parent be42642 commit f324a3b

File tree

3 files changed

+23
-5
lines changed

3 files changed

+23
-5
lines changed

common/chat.cpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,8 @@ std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const comm
126126
typedef minja::chat_template common_chat_template;
127127

128128
struct common_chat_templates {
129+
bool add_bos;
130+
bool add_eos;
129131
bool has_explicit_template; // Model had builtin template or template overridde was specified.
130132
std::unique_ptr<common_chat_template> template_default; // always set (defaults to chatml)
131133
std::unique_ptr<common_chat_template> template_tool_use;
@@ -143,6 +145,8 @@ struct templates_params {
143145
bool enable_thinking = true;
144146
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
145147
json extra_context;
148+
bool add_bos;
149+
bool add_eos;
146150
};
147151

148152
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
@@ -445,6 +449,8 @@ std::string common_chat_format_single(
445449

446450
common_chat_templates_inputs inputs;
447451
inputs.use_jinja = use_jinja;
452+
inputs.add_bos = tmpls->add_bos;
453+
inputs.add_eos = tmpls->add_eos;
448454

449455
std::string fmt_past_msg;
450456
if (!past_msg.empty()) {
@@ -469,6 +475,8 @@ std::string common_chat_format_single(
469475
std::string common_chat_format_example(const struct common_chat_templates * tmpls, bool use_jinja) {
470476
common_chat_templates_inputs inputs;
471477
inputs.use_jinja = use_jinja;
478+
inputs.add_bos = tmpls->add_bos;
479+
inputs.add_eos = tmpls->add_eos;
472480
auto add_simple_msg = [&](auto role, auto content) {
473481
common_chat_msg msg;
474482
msg.role = role;
@@ -546,6 +554,8 @@ common_chat_templates_ptr common_chat_templates_init(
546554
}
547555
std::string token_bos = bos_token_override;
548556
std::string token_eos = eos_token_override;
557+
bool add_bos = false;
558+
bool add_eos = false;
549559
if (model) {
550560
const auto * vocab = llama_model_get_vocab(model);
551561
const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) {
@@ -560,9 +570,13 @@ common_chat_templates_ptr common_chat_templates_init(
560570
};
561571
token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
562572
token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
573+
add_bos = llama_vocab_get_add_bos(vocab);
574+
add_eos = llama_vocab_get_add_eos(vocab);
563575
}
564576
common_chat_templates_ptr tmpls(new common_chat_templates());
565577
tmpls->has_explicit_template = has_explicit_template;
578+
tmpls->add_bos = add_bos;
579+
tmpls->add_eos = add_eos;
566580
try {
567581
tmpls->template_default = std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos);
568582
} catch (const std::exception & e) {
@@ -748,10 +762,10 @@ static std::string apply(
748762
// instead of using `chat_template_options.use_bos_token = false`, since these tokens
749763
// may be needed inside the template / between messages too.
750764
auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
751-
if (string_starts_with(result, tmpl.bos_token())) {
765+
if (inputs.add_bos && string_starts_with(result, tmpl.bos_token())) {
752766
result = result.substr(tmpl.bos_token().size());
753767
}
754-
if (string_ends_with(result, tmpl.eos_token())) {
768+
if (inputs.add_eos && string_ends_with(result, tmpl.eos_token())) {
755769
result = result.substr(0, result.size() - tmpl.eos_token().size());
756770
}
757771
return result;
@@ -1731,6 +1745,8 @@ static common_chat_params common_chat_templates_apply_jinja(
17311745
params.enable_thinking = inputs.enable_thinking;
17321746
params.grammar = inputs.grammar;
17331747
params.now = inputs.now;
1748+
params.add_bos = inputs.add_bos;
1749+
params.add_eos = inputs.add_eos;
17341750

17351751
params.extra_context = json::object();
17361752
for (auto el : inputs.chat_template_kwargs) {

common/chat.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,8 @@ struct common_chat_templates_inputs {
127127
bool enable_thinking = true;
128128
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
129129
std::map<std::string, std::string> chat_template_kwargs;
130+
bool add_bos = false;
131+
bool add_eos = false;
130132
};
131133

132134
struct common_chat_params {

tests/test-chat-template.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ int main(void) {
6161
/* .name= */ "mistralai/Mistral-7B-Instruct-v0.2 (NOTE: Old pre-v1 without a system prompt)",
6262
/* .template_str= */ "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
6363
/* .expected_output= */ "[INST] You are a helpful assistant\nHello [/INST]Hi there</s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
64-
/* .expected_output_jinja= */ "",
64+
/* .expected_output_jinja= */ "<s>[INST] You are a helpful assistant\nHello [/INST]Hi there</s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
6565
/* .bos_token= */ "<s>",
6666
/* .eos_token= */ "</s>",
6767
},
@@ -85,7 +85,7 @@ int main(void) {
8585
/* .name= */ "mlabonne/AlphaMonarch-7B",
8686
/* .template_str= */ "{% for message in messages %}{{bos_token + message['role'] + '\\n' + message['content'] + eos_token + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ bos_token + 'assistant\\n' }}{% endif %}",
8787
/* .expected_output= */ "system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n I am an assistant </s>\n<s>user\nAnother question</s>\n<s>assistant\n",
88-
/* .expected_output_jinja= */ "",
88+
/* .expected_output_jinja= */ "<s>system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n I am an assistant </s>\n<s>user\nAnother question</s>\n<s>assistant\n",
8989
/* .bos_token= */ "<s>",
9090
/* .eos_token= */ "</s>",
9191
},
@@ -99,7 +99,7 @@ int main(void) {
9999
/* .name= */ "OrionStarAI/Orion-14B-Chat",
100100
/* .template_str= */ "{% for message in messages %}{% if loop.first %}{{ bos_token }}{% endif %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'] + '\\n\\nAssistant: ' + eos_token }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}{% endif %}{% endfor %}",
101101
/* .expected_output= */ "Human: You are a helpful assistant\n\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s> I am an assistant </s>Human: Another question\n\nAssistant: </s>",
102-
/* .expected_output_jinja= */ "Human: You are a helpful assistant\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s> I am an assistant </s>Human: Another question\n\nAssistant: ",
102+
/* .expected_output_jinja= */ "Human: You are a helpful assistant\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s> I am an assistant </s>Human: Another question\n\nAssistant: </s>",
103103
/* .bos_token= */ "",
104104
/* .eos_token= */ "</s>",
105105
},

0 commit comments

Comments
 (0)