Skip to content

Commit dae5e73

Browse files
authored
only remove double bos/eos if added
1 parent ee3a9fc commit dae5e73

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

common/chat.cpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,8 @@ std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const comm
126126
typedef minja::chat_template common_chat_template;
127127

128128
struct common_chat_templates {
129+
bool add_bos;
130+
bool add_eos;
129131
bool has_explicit_template; // Model had builtin template or template overridde was specified.
130132
std::unique_ptr<common_chat_template> template_default; // always set (defaults to chatml)
131133
std::unique_ptr<common_chat_template> template_tool_use;
@@ -143,6 +145,8 @@ struct templates_params {
143145
bool enable_thinking = true;
144146
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
145147
json extra_context;
148+
bool add_bos;
149+
bool add_eos;
146150
};
147151

148152
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
@@ -445,6 +449,8 @@ std::string common_chat_format_single(
445449

446450
common_chat_templates_inputs inputs;
447451
inputs.use_jinja = use_jinja;
452+
inputs.add_bos = tmpls->add_bos;
453+
inputs.add_eos = tmpls->add_eos;
448454

449455
std::string fmt_past_msg;
450456
if (!past_msg.empty()) {
@@ -469,6 +475,8 @@ std::string common_chat_format_single(
469475
std::string common_chat_format_example(const struct common_chat_templates * tmpls, bool use_jinja) {
470476
common_chat_templates_inputs inputs;
471477
inputs.use_jinja = use_jinja;
478+
inputs.add_bos = tmpls->add_bos;
479+
inputs.add_eos = tmpls->add_eos;
472480
auto add_simple_msg = [&](auto role, auto content) {
473481
common_chat_msg msg;
474482
msg.role = role;
@@ -546,6 +554,8 @@ common_chat_templates_ptr common_chat_templates_init(
546554
}
547555
std::string token_bos = bos_token_override;
548556
std::string token_eos = eos_token_override;
557+
bool add_bos = false;
558+
bool add_eos = false;
549559
if (model) {
550560
const auto * vocab = llama_model_get_vocab(model);
551561
const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) {
@@ -560,9 +570,13 @@ common_chat_templates_ptr common_chat_templates_init(
560570
};
561571
token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
562572
token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
573+
add_bos = llama_vocab_get_add_bos(vocab);
574+
add_eos = llama_vocab_get_add_eos(vocab);
563575
}
564576
common_chat_templates_ptr tmpls(new common_chat_templates());
565577
tmpls->has_explicit_template = has_explicit_template;
578+
tmpls->add_bos = add_bos;
579+
tmpls->add_eos = add_eos;
566580
try {
567581
tmpls->template_default = std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos);
568582
} catch (const std::exception & e) {
@@ -748,10 +762,10 @@ static std::string apply(
748762
// instead of using `chat_template_options.use_bos_token = false`, since these tokens
749763
// may be needed inside the template / between messages too.
750764
auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
751-
if (string_starts_with(result, tmpl.bos_token())) {
765+
if (inputs.add_bos && string_starts_with(result, tmpl.bos_token())) {
752766
result = result.substr(tmpl.bos_token().size());
753767
}
754-
if (string_ends_with(result, tmpl.eos_token())) {
768+
if (inputs.add_eos && string_ends_with(result, tmpl.eos_token())) {
755769
result = result.substr(0, result.size() - tmpl.eos_token().size());
756770
}
757771
return result;
@@ -1731,6 +1745,8 @@ static common_chat_params common_chat_templates_apply_jinja(
17311745
params.enable_thinking = inputs.enable_thinking;
17321746
params.grammar = inputs.grammar;
17331747
params.now = inputs.now;
1748+
params.add_bos = inputs.add_bos;
1749+
params.add_eos = inputs.add_eos;
17341750

17351751
params.extra_context = json::object();
17361752
for (auto el : inputs.chat_template_kwargs) {

common/chat.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,8 @@ struct common_chat_templates_inputs {
127127
bool enable_thinking = true;
128128
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
129129
std::map<std::string, std::string> chat_template_kwargs;
130+
bool add_bos = false;
131+
bool add_eos = false;
130132
};
131133

132134
struct common_chat_params {

0 commit comments

Comments
 (0)