Skip to content

Commit 118f799

Browse files
author
ochafik
committed
DeepSeek-R1: implement grammar constraints
1 parent 92ac336 commit 118f799

File tree

2 files changed

+23
-3
lines changed

2 files changed

+23
-3
lines changed

common/chat-handler.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -443,10 +443,26 @@ static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat
443443
fprintf(stderr, "[%s]\n", __func__);
444444
common_chat_data data;
445445
data.grammar = "root ::= .*";
446+
// data.grammar = "root ::= .*";
447+
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
448+
std::vector<std::string> tool_rules;
449+
foreach_function(params.tools, [&](const json & tool) {
450+
const auto & function = tool["function"];
451+
std::string name = function["name"];
452+
auto parameters = function["parameters"];
453+
auto args_rule = builder.add_schema(name + "-args", parameters);
454+
tool_rules.push_back(builder.add_rule(name + "-call",
455+
"\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n```json\\n\" " + args_rule + " \"```<|tool▁call▁end|>\""));
456+
});
457+
if (params.tool_choice != "required") {
458+
data.grammar_triggers.push_back({"<|tool▁calls▁begin|>", /* .at_start = */ false});
459+
}
460+
builder.add_rule("root", "\"<|tool▁calls▁begin|>\" (" + string_join(tool_rules, " | ") + ")" + (params.parallel_tool_calls ? "*" : "") + " space");
461+
}, grammar_options);
446462
data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
447463
data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
448464
static std::regex trigger_regex("<|tool▁calls▁begin|>");
449-
static std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^<]+)\n```json\n");
465+
static std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n");
450466
static std::regex close_regex("```<|tool▁call▁end|>");
451467
return parse_json_tool_calls(params.tools, input, trigger_regex, function_regex, close_regex, /* check_names= */ true);
452468
});

tests/test-chat-handler.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,10 @@ static std::string get_message_prompt_delta(const common_chat_template & tmpl, c
353353
throw std::runtime_error("Full message does not start with prefix");
354354
}
355355

356+
if (full == prefix) {
357+
throw std::runtime_error("Full message is the same as the prefix");
358+
}
359+
356360
auto delta = full.substr(prefix.size());
357361

358362
// Strip end tokens
@@ -398,7 +402,7 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
398402

399403
auto content_less_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, {
400404
{"role", "assistant"},
401-
{"content", ""},
405+
{"content", {}},
402406
{"tool_calls", tool_calls}
403407
}, tools);
404408
if (!match_string(content_less_delta, grammar.get())) {
@@ -490,7 +494,7 @@ static void test_grammars() {
490494
}
491495
{
492496
const common_chat_template tmpl(read_file("tests/chat/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), "<s>", "</s>");
493-
test_template(tmpl, {}, tool_call_message, tools);
497+
test_template(tmpl, { "<|end▁of▁sentence|>" }, tool_call_message, tools);
494498
}
495499
}
496500

0 commit comments

Comments
 (0)