Skip to content

Commit f55c434

Browse files
committed
Changed system message logic and added tests for all 4 Mistral templates (mistral-v1, mistral-v2, mistral-v3, mistral-v3-tekken)
1 parent dbbde92 commit f55c434

File tree

2 files changed

+15
-7
lines changed

2 files changed

+15
-7
lines changed

src/llama.cpp

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -21867,21 +21867,21 @@ static int32_t llama_chat_apply_template_internal(
2186721867
// See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/templates.md
2186821868
std::string leading_space = (tmpl == "mistral-v1" ? " " : "");
2186921869
std::string trailing_space = (tmpl != "mistral-v3-tekken" ? " " : "");
21870-
std::string system_message = "";
21870+
bool is_inside_turn = false;
2187121871
for (auto message : chat) {
21872+
if (!is_inside_turn) {
21873+
ss << leading_space << "[INST]" << trailing_space;
21874+
is_inside_turn = true;
21875+
}
2187221876
std::string role(message->role);
2187321877
std::string content = trim(message->content);
2187421878
if (role == "system") {
21875-
system_message = content;
21879+
ss << system_message << "\n\n";
2187621880
} else if (role == "user") {
21877-
ss << leading_space << "[INST]" << trailing_space;
21878-
if (!system_message.empty()) {
21879-
ss << system_message << "\n\n";
21880-
system_message = "";
21881-
}
2188221881
ss << content << leading_space << "[/INST]";
2188321882
} else {
2188421883
ss << trailing_space << content << "</s>";
21884+
is_inside_turn = false;
2188521885
}
2188621886
}
2188721887
} else if (tmpl == "llama2" || tmpl == "mistral" || tmpl_contains("[INST]")) {

tests/test-chat-template.cpp

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -154,6 +154,10 @@ int main(void) {
154154
return output;
155155
};
156156
assert(fmt_sys("chatml") == "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n");
157+
assert(fmt_sys("mistral-v1") == " [INST] You are a helpful assistant\n\n");
158+
assert(fmt_sys("mistral-v2") == "[INST] You are a helpful assistant\n\n");
159+
assert(fmt_sys("mistral-v3") == "[INST] You are a helpful assistant\n\n");
160+
assert(fmt_sys("mistral-v3-tekken") == "[INST]You are a helpful assistant\n\n");
157161
assert(fmt_sys("llama2") == "[INST] You are a helpful assistant\n");
158162
assert(fmt_sys("gemma") == ""); // for gemma, system message is merged with user message
159163
assert(fmt_sys("llama3") == "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|>");
@@ -173,6 +177,10 @@ int main(void) {
173177
return output;
174178
};
175179
assert(fmt_single("chatml") == "\n<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n");
180+
assert(fmt_single("mistral-v1") == " [INST] How are you [/INST]");
181+
assert(fmt_single("mistral-v2") == "[INST] How are you[/INST]");
182+
assert(fmt_single("mistral-v3") == "[INST] How are you[/INST]");
183+
assert(fmt_single("mistral-v3-tekken") == "[INST]How are you[/INST]");
176184
assert(fmt_single("llama2") == "[INST] How are you [/INST]");
177185
assert(fmt_single("gemma") == "\n<start_of_turn>user\nHow are you<end_of_turn>\n<start_of_turn>model\n");
178186
assert(fmt_single("llama3") == "<|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");

0 commit comments

Comments
 (0)