
Commit 537e157

swordowleongyi authored and committed

mod: chat template add support for deepseek-r1-distill series

1 parent e675276

File tree

3 files changed: +64 −34 lines

src/llama-chat.cpp

Lines changed: 54 additions & 34 deletions
@@ -26,39 +26,40 @@ static std::string trim(const std::string & str) {
 }

 static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
-    { "chatml",            LLM_CHAT_TEMPLATE_CHATML },
-    { "llama2",            LLM_CHAT_TEMPLATE_LLAMA_2 },
-    { "llama2-sys",        LLM_CHAT_TEMPLATE_LLAMA_2_SYS },
-    { "llama2-sys-bos",    LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS },
-    { "llama2-sys-strip",  LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP },
-    { "mistral-v1",        LLM_CHAT_TEMPLATE_MISTRAL_V1 },
-    { "mistral-v3",        LLM_CHAT_TEMPLATE_MISTRAL_V3 },
-    { "mistral-v3-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN },
-    { "mistral-v7",        LLM_CHAT_TEMPLATE_MISTRAL_V7 },
-    { "phi3",              LLM_CHAT_TEMPLATE_PHI_3 },
-    { "phi4",              LLM_CHAT_TEMPLATE_PHI_4 },
-    { "falcon3",           LLM_CHAT_TEMPLATE_FALCON_3 },
-    { "zephyr",            LLM_CHAT_TEMPLATE_ZEPHYR },
-    { "monarch",           LLM_CHAT_TEMPLATE_MONARCH },
-    { "gemma",             LLM_CHAT_TEMPLATE_GEMMA },
-    { "orion",             LLM_CHAT_TEMPLATE_ORION },
-    { "openchat",          LLM_CHAT_TEMPLATE_OPENCHAT },
-    { "vicuna",            LLM_CHAT_TEMPLATE_VICUNA },
-    { "vicuna-orca",       LLM_CHAT_TEMPLATE_VICUNA_ORCA },
-    { "deepseek",          LLM_CHAT_TEMPLATE_DEEPSEEK },
-    { "deepseek2",         LLM_CHAT_TEMPLATE_DEEPSEEK_2 },
-    { "deepseek3",         LLM_CHAT_TEMPLATE_DEEPSEEK_3 },
-    { "command-r",         LLM_CHAT_TEMPLATE_COMMAND_R },
-    { "llama3",            LLM_CHAT_TEMPLATE_LLAMA_3 },
-    { "chatglm3",          LLM_CHAT_TEMPLATE_CHATGML_3 },
-    { "chatglm4",          LLM_CHAT_TEMPLATE_CHATGML_4 },
-    { "glmedge",           LLM_CHAT_TEMPLATE_GLMEDGE },
-    { "minicpm",           LLM_CHAT_TEMPLATE_MINICPM },
-    { "exaone3",           LLM_CHAT_TEMPLATE_EXAONE_3 },
-    { "rwkv-world",        LLM_CHAT_TEMPLATE_RWKV_WORLD },
-    { "granite",           LLM_CHAT_TEMPLATE_GRANITE },
-    { "gigachat",          LLM_CHAT_TEMPLATE_GIGACHAT },
-    { "megrez",            LLM_CHAT_TEMPLATE_MEGREZ },
+    { "chatml",              LLM_CHAT_TEMPLATE_CHATML },
+    { "llama2",              LLM_CHAT_TEMPLATE_LLAMA_2 },
+    { "llama2-sys",          LLM_CHAT_TEMPLATE_LLAMA_2_SYS },
+    { "llama2-sys-bos",      LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS },
+    { "llama2-sys-strip",    LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP },
+    { "mistral-v1",          LLM_CHAT_TEMPLATE_MISTRAL_V1 },
+    { "mistral-v3",          LLM_CHAT_TEMPLATE_MISTRAL_V3 },
+    { "mistral-v3-tekken",   LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN },
+    { "mistral-v7",          LLM_CHAT_TEMPLATE_MISTRAL_V7 },
+    { "phi3",                LLM_CHAT_TEMPLATE_PHI_3 },
+    { "phi4",                LLM_CHAT_TEMPLATE_PHI_4 },
+    { "falcon3",             LLM_CHAT_TEMPLATE_FALCON_3 },
+    { "zephyr",              LLM_CHAT_TEMPLATE_ZEPHYR },
+    { "monarch",             LLM_CHAT_TEMPLATE_MONARCH },
+    { "gemma",               LLM_CHAT_TEMPLATE_GEMMA },
+    { "orion",               LLM_CHAT_TEMPLATE_ORION },
+    { "openchat",            LLM_CHAT_TEMPLATE_OPENCHAT },
+    { "vicuna",              LLM_CHAT_TEMPLATE_VICUNA },
+    { "vicuna-orca",         LLM_CHAT_TEMPLATE_VICUNA_ORCA },
+    { "deepseek",            LLM_CHAT_TEMPLATE_DEEPSEEK },
+    { "deepseek2",           LLM_CHAT_TEMPLATE_DEEPSEEK_2 },
+    { "deepseek3",           LLM_CHAT_TEMPLATE_DEEPSEEK_3 },
+    { "deepseek-r1-distill", LLM_CHAT_TEMPLATE_DEEPSEEK_R1_DISTILL },
+    { "command-r",           LLM_CHAT_TEMPLATE_COMMAND_R },
+    { "llama3",              LLM_CHAT_TEMPLATE_LLAMA_3 },
+    { "chatglm3",            LLM_CHAT_TEMPLATE_CHATGML_3 },
+    { "chatglm4",            LLM_CHAT_TEMPLATE_CHATGML_4 },
+    { "glmedge",             LLM_CHAT_TEMPLATE_GLMEDGE },
+    { "minicpm",             LLM_CHAT_TEMPLATE_MINICPM },
+    { "exaone3",             LLM_CHAT_TEMPLATE_EXAONE_3 },
+    { "rwkv-world",          LLM_CHAT_TEMPLATE_RWKV_WORLD },
+    { "granite",             LLM_CHAT_TEMPLATE_GRANITE },
+    { "gigachat",            LLM_CHAT_TEMPLATE_GIGACHAT },
+    { "megrez",              LLM_CHAT_TEMPLATE_MEGREZ },
 };

 llm_chat_template llm_chat_template_from_str(const std::string & name) {
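Note: with this map entry, the template can also be forced by name instead of relying on auto-detection, e.g. via llama.cpp's usual --chat-template deepseek-r1-distill override; the flag usage is shown here for illustration and is not part of this diff.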
@@ -154,6 +155,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
         return LLM_CHAT_TEMPLATE_MINICPM;
     } else if (tmpl_contains("'Assistant: ' + message['content'] + eos_token")) {
         return LLM_CHAT_TEMPLATE_DEEPSEEK_2;
+    } else if (tmpl_contains(LU8("<|Assistant|>")) && tmpl_contains(LU8("<|User|>")) && tmpl_contains(LU8("<|end▁of▁sentence|>")) && tmpl_contains(LU8("<|Assistant|><think>\\n"))) {
+        return LLM_CHAT_TEMPLATE_DEEPSEEK_R1_DISTILL;
     } else if (tmpl_contains(LU8("<|Assistant|>")) && tmpl_contains(LU8("<|User|>")) && tmpl_contains(LU8("<|end▁of▁sentence|>"))) {
         return LLM_CHAT_TEMPLATE_DEEPSEEK_3;
     } else if (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]")) {
@@ -492,7 +495,24 @@ int32_t llm_chat_apply_template(
         if (add_ass) {
             ss << LU8("<|Assistant|>");
         }
-    } else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_3) {
+    } else if (tmpl == LLM_CHAT_TEMPLATE_DEEPSEEK_R1_DISTILL) {
+        // DeepSeek-R1-Distill
+        ss << "<|begin▁of▁sentence|>";
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << message->content;
+            } else if (role == "user") {
+                ss << LU8("<|User|>") << message->content;
+            } else if (role == "assistant") {
+                ss << LU8("<|Assistant|>") << message->content << LU8("<|end▁of▁sentence|>");
+            }
+        }
+        if (add_ass) {
+            ss << LU8("<|Assistant|>");
+        }
+    }
+    else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_3) {
         // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb
         // EXAONE-3.0-7.8B-Instruct
         for (auto message : chat) {
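For reference, here is a self-contained sketch of the rendering the new branch produces; the msg struct and render helper are illustrative stand-ins for llama_chat_message and the llm_chat_apply_template machinery, not the library's API:

#include <iostream>
#include <sstream>
#include <string>
#include <vector>

struct msg { std::string role, content; };

static std::string render(const std::vector<msg> & chat, bool add_ass) {
    std::ostringstream ss;
    ss << "<|begin▁of▁sentence|>";
    for (const auto & m : chat) {
        if (m.role == "system") {
            ss << m.content;                           // system text is inlined, no role tag
        } else if (m.role == "user") {
            ss << "<|User|>" << m.content;
        } else if (m.role == "assistant") {
            ss << "<|Assistant|>" << m.content << "<|end▁of▁sentence|>";
        }
    }
    if (add_ass) {
        ss << "<|Assistant|>";                         // generation prompt
    }
    return ss.str();
}

int main() {
    std::cout << render({{"system", "You are a helpful assistant"},
                         {"user",   "Hello"}}, /*add_ass=*/true) << "\n";
    // <|begin▁of▁sentence|>You are a helpful assistant<|User|>Hello<|Assistant|>
}

Note that system content is emitted bare and each completed assistant turn is closed with <|end▁of▁sentence|>.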

src/llama-chat.h

Lines changed: 1 addition & 0 deletions

@@ -27,6 +27,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_DEEPSEEK,
     LLM_CHAT_TEMPLATE_DEEPSEEK_2,
     LLM_CHAT_TEMPLATE_DEEPSEEK_3,
+    LLM_CHAT_TEMPLATE_DEEPSEEK_R1_DISTILL,
     LLM_CHAT_TEMPLATE_COMMAND_R,
     LLM_CHAT_TEMPLATE_LLAMA_3,
     LLM_CHAT_TEMPLATE_CHATGML_3,

tests/test-chat-template.cpp

Lines changed: 9 additions & 0 deletions

@@ -270,6 +270,15 @@ int main(void) {
             /* .bos_token= */ "",
             /* .eos_token= */ "",
         },
+        {
+            /* .name= */ "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
+            /* .template_str= */ "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|><think>\\n'}}{% endif %}",
+            /* .expected_output= */ "<|begin▁of▁sentence|>You are a helpful assistant<|User|>Hello<|Assistant|>Hi there<|end▁of▁sentence|><|User|>Who are you<|Assistant|> I am an assistant <|end▁of▁sentence|><|User|>Another question<|Assistant|><think>\n",
+            /* .expected_output_jinja= */ "<|begin▁of▁sentence|>You are a helpful assistant<|User|>Hello<|Assistant|>Hi there<|end▁of▁sentence|><|User|>Who are you<|Assistant|> I am an assistant <|end▁of▁sentence|><|User|>Another question<|Assistant|><think>\n",
+            /* .bos_token= */ "<|begin▁of▁sentence|>",
+            /* .eos_token= */ "<|end▁of▁sentence|>",
+            /* .supported_with_jinja= */ true, // Requires additional_special_tokens as extra context
+        },
     };
     std::vector<char> formatted_chat(1024);
     int32_t res;
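One difference worth noting: the Jinja template above strips earlier reasoning from assistant history via content.split('</think>')[-1], while the new C++ branch writes assistant content verbatim. A hypothetical helper with the same semantics as the Jinja split (strip_reasoning is illustrative, not part of this patch):

#include <iostream>
#include <string>

static std::string strip_reasoning(const std::string & content) {
    const std::string tag = "</think>";
    const size_t pos = content.rfind(tag);
    // Keep only the text after the last closing tag, matching split('</think>')[-1].
    return pos == std::string::npos ? content : content.substr(pos + tag.size());
}

int main() {
    std::cout << strip_reasoning("<think>chain of thought</think>Final answer") << "\n";
    // Final answer
}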
