Commit 1f5ec59

ochafik committed
ensure deepseek r1 thoughts parsed even w/o tool calls
1 parent b6e14a4 commit 1f5ec59

File tree: 2 files changed (+91, -49 lines)

common/chat.cpp (39 additions, 37 deletions)

@@ -565,39 +565,41 @@ static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bo
 
 static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
     common_chat_params data;
-    data.grammar_lazy = inputs.tool_choice != "required";
-    data.grammar = build_grammar([&](const common_grammar_builder & builder) {
-        std::vector<std::string> tool_rules;
-        foreach_function(inputs.tools, [&](const json & tool) {
-            const auto & function = tool["function"];
-            std::string name = function["name"];
-            auto parameters = function["parameters"];
-            auto args_rule = builder.add_schema(name + "-args", parameters);
-            tool_rules.push_back(builder.add_rule(name + "-call",
-                "\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n"
-                "```json\\n\" " + args_rule + " \"```<|tool▁call▁end|>\""));
-        });
-        // Distill Qwen 7B & 32B models seem confused re/ syntax of their tool call opening tag,
-        // so we accept common variants (then it's all constrained)
-        builder.add_rule("root",
-            "( \"<|tool▁calls▁begin|>\" | \"<|tool_calls_begin|>\" | \"<|tool calls begin|>\" | \"<|tool\\\\_calls\\\\_begin|>\" ) "
-            "(" +string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " "
-            "\"<|tool▁calls▁end|>\""
-            " space");
-        data.grammar_triggers.push_back({"<|tool▁calls▁begin|>", /* .at_start = */ false});
-        data.grammar_triggers.push_back({"<|tool_calls_begin|>", /* .at_start = */ false});
-        data.grammar_triggers.push_back({"<|tool calls begin|>", /* .at_start = */ false});
-        data.grammar_triggers.push_back({"<|tool\\_calls\\_begin|>", /* .at_start = */ false});
-        data.grammar_triggers.push_back({"<|tool▁call▁begin|>", /* .at_start = */ false});
-        data.preserved_tokens = {
-            "<think>",
-            "</think>",
-            "<|tool▁sep|>",
-            "<|tool▁calls▁end|",
-            "<|tool▁call▁begin|>",
-            "<|tool▁call▁end|>",
-        };
-    }, grammar_options);
+    if (!inputs.tools.is_null() && !inputs.tools.empty()) {
+        data.grammar_lazy = inputs.tool_choice != "required";
+        data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+            std::vector<std::string> tool_rules;
+            foreach_function(inputs.tools, [&](const json & tool) {
+                const auto & function = tool["function"];
+                std::string name = function["name"];
+                auto parameters = function["parameters"];
+                auto args_rule = builder.add_schema(name + "-args", parameters);
+                tool_rules.push_back(builder.add_rule(name + "-call",
+                    "\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n"
+                    "```json\\n\" " + args_rule + " \"```<|tool▁call▁end|>\""));
+            });
+            // Distill Qwen 7B & 32B models seem confused re/ syntax of their tool call opening tag,
+            // so we accept common variants (then it's all constrained)
+            builder.add_rule("root",
+                "( \"<|tool▁calls▁begin|>\" | \"<|tool_calls_begin|>\" | \"<|tool calls begin|>\" | \"<|tool\\\\_calls\\\\_begin|>\" ) "
+                "(" +string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " "
+                "\"<|tool▁calls▁end|>\""
+                " space");
+            data.grammar_triggers.push_back({"<|tool▁calls▁begin|>", /* .at_start = */ false});
+            data.grammar_triggers.push_back({"<|tool_calls_begin|>", /* .at_start = */ false});
+            data.grammar_triggers.push_back({"<|tool calls begin|>", /* .at_start = */ false});
+            data.grammar_triggers.push_back({"<|tool\\_calls\\_begin|>", /* .at_start = */ false});
+            data.grammar_triggers.push_back({"<|tool▁call▁begin|>", /* .at_start = */ false});
+            data.preserved_tokens = {
+                "<think>",
+                "</think>",
+                "<|tool▁sep|>",
+                "<|tool▁calls▁end|",
+                "<|tool▁call▁begin|>",
+                "<|tool▁call▁end|>",
+            };
+        }, grammar_options);
+    }
     auto prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
 
     // Hacks to fix the official (broken) prompt.
@@ -638,7 +640,7 @@ static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input)
             auto msg2 = parse_json_tool_calls(tool_calls, std::nullopt, function_regex, close_regex);
             msg.tool_calls = std::move(msg2.tool_calls);
         } else {
-            msg.content = rest;
+            msg.content = std::string(rest.begin() + rest.find_first_not_of(" \r\n"), rest.end());
         }
     } else {
         msg.content = input;
@@ -970,6 +972,9 @@ common_chat_params common_chat_params_init(const common_chat_template & tmpl, co
        // Firefunction v2 requires datetime and functions in the context, even w/o tools.
        return common_chat_params_init_firefunction_v2(tmpl, inputs);
    }
+    if (src.find("<|tool▁calls▁begin|>") != std::string::npos) {
+        return common_chat_params_init_deepseek_r1(tmpl, inputs);
+    }
 
    if (!has_tools) {
        return common_chat_params_init_without_tools(tmpl, inputs);
@@ -986,9 +991,6 @@ common_chat_params common_chat_params_init(const common_chat_template & tmpl, co
        auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
        return common_chat_params_init_llama_3_1_tool_calls(tmpl, inputs, allow_python_tag_builtin_tools);
    }
-    if (src.find("<|tool▁calls▁begin|>") != std::string::npos) {
-        return common_chat_params_init_deepseek_r1(tmpl, inputs);
-    }
    if (src.find("[TOOL_CALLS]") != std::string::npos) {
        return common_chat_params_init_mistral_nemo(tmpl, inputs);
    }
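
Net effect of the common/chat.cpp changes: the DeepSeek R1 handler is now selected whenever the template contains <|tool▁calls▁begin|>, even when the request carries no tools, and its tool-call grammar is only built when tools are actually present. That way the <think>...</think> block is still split out of plain responses and leading whitespace is trimmed from the visible content. The standalone sketch below is illustrative only, not code from the repository; split_thoughts and its exact trimming behaviour are assumptions made for the example.

// Minimal illustration: split a DeepSeek-R1-style completion into "thoughts"
// (the <think>...</think> block) and "content" (what follows), trimming the
// leading whitespace the way the parser change does. split_thoughts is a
// hypothetical helper, not part of common/chat.cpp.
#include <iostream>
#include <string>
#include <utility>

static std::pair<std::string, std::string> split_thoughts(const std::string & input) {
    const std::string open  = "<think>";
    const std::string close = "</think>";
    const auto start = input.find(open);
    const auto end   = input.find(close, start == std::string::npos ? 0 : start);
    if (start == std::string::npos || end == std::string::npos) {
        return {"", input}; // no reasoning block: everything is visible content
    }
    std::string thoughts = input.substr(start + open.size(), end - start - open.size());
    std::string rest     = input.substr(end + close.size());
    const auto first = rest.find_first_not_of(" \r\n"); // drop leading blank lines, as in the commit
    std::string content  = first == std::string::npos ? std::string() : rest.substr(first);
    return {thoughts, content};
}

int main() {
    const auto [thoughts, content] =
        split_thoughts("<think>I need to add 102 and 7.</think>\n\nThe sum of 102 and 7 is 109.");
    std::cout << "thoughts: " << thoughts << "\n"; // I need to add 102 and 7.
    std::cout << "content : " << content  << "\n"; // The sum of 102 and 7 is 109.
}

With the reordered dispatch, a no-tools request against a DeepSeek R1 template goes through this R1-specific parsing path instead of the generic no-tools handler, which is exactly what the new test_thoughts test below exercises.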

examples/server/tests/unit/test_tool_call.py (52 additions, 12 deletions)

@@ -345,20 +345,20 @@ def test_weather_tool_call(hf_repo: str, template_override: str | Tuple[str, str
 
 @pytest.mark.slow
 @pytest.mark.parametrize("result_override,n_predict,hf_repo,template_override", [
-    (None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
-    (None, 128, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
-    (None, 128, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
-    (None, 128, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
-    (None, 128, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
-    (None, 128, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
-    (None, 128, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
-    (None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
-    ("^> 0.56$", 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
+    (None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
+    (None, 128, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
+    (None, 128, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
+    (None, 128, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
+    (None, 128, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
+    (None, 128, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
+    (None, 128, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
+    (None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
+    ("^> 0.56$", 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
 
     # TODO: fix these (wrong results, either didn't respect decimal instruction or got wrong value)
-    ("[\\s\\S\\r\\n]*?\\b0\\.55644242476$", 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
-    ("[\\s\\S\\r\\n]*?which equals 0\\.5\\.", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
-    ("**Answer:** 0\\.25\\b", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
+    ("[\\s\\S\\r\\n]*?\\b0\\.55644242476$", 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
+    ("[\\s\\S\\r\\n]*?which equals 0\\.5\\.", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
+    ("[\\s\\S\\r\\n]*?\\*\\*Answer:\\*\\* 0\\.25\\b", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
 ])
 def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
     global server
@@ -435,6 +435,46 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str,
         f'Expected something like "The y coordinate is 0.56.", got {content}'
 
 
+@pytest.mark.slow
+@pytest.mark.parametrize("n_predict,expect_content,expect_thoughts,hf_repo,template_override", [
+    (128, "^The sum of 102 and 7 is 109.*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
+    (1024, "To find the sum of.*", "I need to calculate the sum of 102 and 7.*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
+    (1024, "To find the sum of.*", "First, I need to add the tens place.*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
+])
+def test_thoughts(n_predict: int, expect_content: str | None, expect_thoughts: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
+    global server
+    server.n_slots = 1
+    server.jinja = True
+    server.n_ctx = 8192 * 2
+    server.n_predict = n_predict
+    server.model_hf_repo = hf_repo
+    server.model_hf_file = None
+    if isinstance(template_override, tuple):
+        (template_hf_repo, template_variant) = template_override
+        server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
+        assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
+    elif isinstance(template_override, str):
+        server.chat_template = template_override
+    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    res = server.make_request("POST", "/chat/completions", data={
+        "max_tokens": n_predict,
+        "messages": [
+            {"role": "user", "content": "What's the sum of 102 and 7?"},
+        ]
+    }, timeout=TIMEOUT_HTTP_REQUEST)
+    assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
+    choice = res.body["choices"][0]
+    assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'
+
+    content = choice["message"].get("content")
+    if expect_content is not None:
+        assert re.match(expect_content, content), f'Expected {expect_content}, got {content}'
+
+    thoughts = choice["message"].get("thoughts")
+    if expect_thoughts is not None:
+        assert re.match(expect_thoughts, thoughts), f'Expected {expect_thoughts}, got {thoughts}'
+
+
 @pytest.mark.slow
 @pytest.mark.parametrize("expected_arguments_override,hf_repo,template_override", [
     (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
