
Commit 7a61eca

ochafik committed:
fix thinking model's initial trigger (take 2) + test qwq's template

1 parent e2cef66

File tree: 5 files changed, +157 -19 lines

  common/chat.cpp                      (+19 -12)
  models/templates/Qwen-QwQ-32B.jinja  (+62 -0, new file)
  models/templates/README.md           (+1 -0)
  src/llama-grammar.cpp                (+1 -1)
  tests/test-chat.cpp                  (+74 -6)

common/chat.cpp

Lines changed: 19 additions & 12 deletions

@@ -899,13 +899,15 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
             schema["maxItems"] = 1;
         }
         builder.add_rule("root",
-            std::string(data.thinking_forced_open ? "\"<|END_THINKING|>\" space " : "") +
+            std::string(data.thinking_forced_open ? "( \"<|END_THINKING|>\" space )? " : "") +
             "\"<|START_ACTION|>\" " + builder.add_schema("tool_calls", schema) + " \"<|END_ACTION|>\"");
     });
     data.grammar_triggers.push_back({
         COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
-        std::string(data.thinking_forced_open ? "[\\s\\S]*?(<\\|END_THINKING\\|>)" : "(?:<\\|START_THINKING\\|>[\\s\\S]*?(<\\|END_THINKING\\|>))?") +
-        "\\s*(<\\|START_ACTION\\|>)[\\s\\S]*"
+        // If thinking_forced_open, then we capture the </think> tag in the grammar,
+        // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar)
+        std::string(data.thinking_forced_open ? "[\\s\\S]*?(<\\|END_THINKING\\|>\\s*)" : "(?:<\\|START_THINKING\\|>[\\s\\S]*?<\\|END_THINKING\\|>\\s*)?") +
+        "(<\\|START_ACTION\\|>)[\\s\\S]*"
     });
     data.preserved_tokens = {
         "<|START_ACTION|>",

@@ -1140,15 +1142,17 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
         // Distill Qwen 7B & 32B models seem confused re/ syntax of their tool call opening tag,
         // so we accept common variants (then it's all constrained)
         builder.add_rule("root",
-            std::string(data.thinking_forced_open ? "\"</think>\" space " : "") +
+            std::string(data.thinking_forced_open ? "( \"</think>\" space )? " : "") +
             "( \"<|tool▁calls▁begin|>\" | \"<|tool_calls_begin|>\" | \"<|tool calls begin|>\" | \"<|tool\\\\_calls\\\\_begin|>\" | \"<|tool▁calls|>\" ) "
             "(" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " "
            "\"<|tool▁calls▁end|>\""
            " space");
         data.grammar_triggers.push_back({
             COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
-            std::string(data.thinking_forced_open ? "[\\s\\S]*?(</think>)" : "(?:<think>[\\s\\S]*?(</think>))?") +
-            "\\s*(<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)[\\s\\S]*"
+            // If thinking_forced_open, then we capture the </think> tag in the grammar,
+            // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar)
+            std::string(data.thinking_forced_open ? "[\\s\\S]*?(</think>\\s*)" : "(?:<think>[\\s\\S]*?</think>\\s*)?") +
+            "(<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)[\\s\\S]*"
         });
         data.preserved_tokens = {
             "<think>",

@@ -1430,13 +1434,15 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
                 "( \"```\\n\" | \"```json\\n\" | \"```xml\\n\" ) space " + wrappable_tool_call + " space \"```\" space ");
         auto tool_call = builder.add_rule("tool_call", string_join(tool_call_alts, " | "));
         builder.add_rule("root",
-            std::string(data.thinking_forced_open ? "\"</think>\" space " : "") +
+            std::string(data.thinking_forced_open ? "( \"</think>\" space )? " : "") +
             (inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call));
         // Trigger on some common known "good bad" outputs (only from the start and with a json that's about a specific argument name to avoid false positives)
         data.grammar_triggers.push_back({
             COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
-            std::string(data.thinking_forced_open ? "[\\s\\S]*?(</think>)" : "(?:<think>[\\s\\S]*?(</think>))?") + (
-                "\\s*(<tool_call>"
+            // If thinking_forced_open, then we capture the </think> tag in the grammar,
+            // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar)
+            std::string(data.thinking_forced_open ? "[\\s\\S]*?(</think>\\s*)" : "(?:<think>[\\s\\S]*?</think>\\s*)?") + (
+                "(<tool_call>"
                 "|<function"
                 "|(?:```(?:json|xml)?\n\\s*)?(?:<function_call>|<tools>|<xml><json>|<response>)?\\s*\\{\\s*\""
             ")[\\s\\S]*"

@@ -1490,12 +1496,13 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {

     auto start = builder.pos();
     if (auto res = builder.try_find_regex(open_regex)) {
-        if (res->groups[0].begin != start && res->groups[4].empty() && res->groups[5].empty()) {
-            // The only syntax we allow after the very start is <function=...> or <function name=...>
+        if (res->groups[0].begin != start && builder.str(res->groups[2]) != "<tool_call>" && res->groups[4].empty() && res->groups[5].empty()) {
+            // The only syntaxes we allow after the very start are <tool_call>, <function=...> or <function name=...>
+            builder.move_to(start);
             builder.add_content(builder.consume_rest());
             return;
         }
-        GGML_ASSERT(res->prelude.empty()); // matching at_start
+        builder.add_content(res->prelude);

         const auto & block_start = res->groups[1];
         std::string block_end = block_start.empty() ? "" : "```";
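
Note on the trigger change above: with thinking_forced_open, the full-match trigger pattern now captures the closing tag ("</think>" or "<|END_THINKING|>") plus trailing whitespace in its first group, and the root grammar rule makes that closing tag optional, so the forwarded text parses whether or not the model has already closed its reasoning block. The sketch below is a standalone illustration, not llama.cpp internals: it shows how such a PATTERN_FULL trigger picks the point where grammar-constrained sampling starts, namely the position of the first capture group that matched anything (the same rule as the loop changed in src/llama-grammar.cpp further down). It reuses the forced-open Command R7B pattern from the hunk above; the sample output string is made up for the example.

// first_capture_demo.cpp -- illustrative only, not part of the commit
#include <iostream>
#include <regex>
#include <string>

int main() {
    // With thinking forced open, the prompt already opened the reasoning block,
    // so the model's output starts mid-reasoning and closes the block itself.
    const std::string out =
        "Let me check the weather first.<|END_THINKING|>"
        "<|START_ACTION|>[{\"tool_name\": \"get_weather\", \"parameters\": {}}]<|END_ACTION|>";

    // Forced-open Command R7B trigger from the diff: group 1 captures the closing
    // tag (plus trailing whitespace), group 2 captures the action opener.
    const std::regex trigger(
        "[\\s\\S]*?(<\\|END_THINKING\\|>\\s*)(<\\|START_ACTION\\|>)[\\s\\S]*");

    std::smatch m;
    if (std::regex_match(out, m, trigger)) {
        // Find the position of the first capture group that matched a non-empty substring.
        size_t start = std::string::npos;
        for (size_t i = 1; i < m.size(); i++) {
            if (m.length(i) > 0) {
                start = m.position(i);
                break;
            }
        }
        // Everything from here on is what would be handed to the GBNF grammar.
        // In this example it begins at "<|END_THINKING|>", which the updated root
        // rule now treats as an optional prefix.
        std::cout << out.substr(start) << "\n";
    }
    return 0;
}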
models/templates/Qwen-QwQ-32B.jinja

Lines changed: 62 additions & 0 deletions

@@ -0,0 +1,62 @@
+{%- if tools %}
+{{- '<|im_start|>system\n' }}
+{%- if messages[0]['role'] == 'system' %}
+{{- messages[0]['content'] }}
+{%- else %}
+{{- '' }}
+{%- endif %}
+{{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
+{%- for tool in tools %}
+{{- "\n" }}
+{{- tool | tojson }}
+{%- endfor %}
+{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
+{%- else %}
+{%- if messages[0]['role'] == 'system' %}
+{{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' }}
+{%- endif %}
+{%- endif %}
+{%- for message in messages %}
+{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
+{{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
+{%- elif message.role == "assistant" and not message.tool_calls %}
+{%- set content = message.content %}
+{%- if not loop.last %}
+{%- set content = message.content.split('</think>')[-1].lstrip('\n') %}
+{%- endif %}
+{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
+{%- elif message.role == "assistant" %}
+{%- set content = message.content %}
+{%- if not loop.last %}
+{%- set content = message.content.split('</think>')[-1].lstrip('\n') %}
+{%- endif %}
+{{- '<|im_start|>' + message.role }}
+{%- if message.content %}
+{{- '\n' + content }}
+{%- endif %}
+{%- for tool_call in message.tool_calls %}
+{%- if tool_call.function is defined %}
+{%- set tool_call = tool_call.function %}
+{%- endif %}
+{{- '\n<tool_call>\n{"name": "' }}
+{{- tool_call.name }}
+{{- '", "arguments": ' }}
+{{- tool_call.arguments | tojson }}
+{{- '}\n</tool_call>' }}
+{%- endfor %}
+{{- '<|im_end|>\n' }}
+{%- elif message.role == "tool" %}
+{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %}
+{{- '<|im_start|>user' }}
+{%- endif %}
+{{- '\n<tool_response>\n' }}
+{{- message.content }}
+{{- '\n</tool_response>' }}
+{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
+{{- '<|im_end|>\n' }}
+{%- endif %}
+{%- endif %}
+{%- endfor %}
+{%- if add_generation_prompt %}
+{{- '<|im_start|>assistant\n<think>\n' }}
+{%- endif %}

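The point to notice in this template: when add_generation_prompt is set, the prompt ends with '<|im_start|>assistant\n<think>\n', i.e. the reasoning block is already open when generation starts. That is why the new QwQ test below expects thinking_forced_open == true, and why the trigger patterns above must tolerate output that begins mid-reasoning. Tracing the template by hand for a minimal tool-less exchange (the system and user texts are made-up placeholders) gives roughly:

<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
What's the weather in Paris?<|im_end|>
<|im_start|>assistant
<think>
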
models/templates/README.md

Lines changed: 1 addition & 0 deletions

@@ -19,4 +19,5 @@ These templates can be updated with the following commands:
 ./scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B tool_use > models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja
 ./scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use > models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja
 ./scripts/get_chat_template.py Qwen/Qwen2.5-7B-Instruct > models/templates/Qwen-Qwen2.5-7B-Instruct.jinja
+./scripts/get_chat_template.py Qwen/QwQ-32B > models/templates/Qwen-QwQ-32B.jinja
 ```

src/llama-grammar.cpp

Lines changed: 1 addition & 1 deletion

@@ -1180,7 +1180,7 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
     // get from the first matched capturing group to the end of the string
     size_t start = std::string::npos;
     for (auto i = 1u; i < match.size(); i++) {
-        if (!match.str(i).empty()) {
+        if (match.length(i) > 0) {
             start = match.position(i);
             break;
         }

tests/test-chat.cpp

Lines changed: 74 additions & 6 deletions

@@ -327,7 +327,17 @@ static void test_templates(const struct common_chat_templates * tmpls, const std
            {
                const auto & pattern = trigger.value;
                if (std::regex_match(constrained, match, std::regex(pattern))) {
-                    pos = match.position(1);
+                    auto mpos = std::string::npos;
+                    for (size_t i = 1; i < match.size(); ++i) {
+                        if (match[i].length() > 0) {
+                            mpos = match.position(i);
+                            break;
+                        }
+                    }
+                    if (mpos == std::string::npos) {
+                        mpos = match.position(0);
+                    }
+                    pos = mpos;
                }
                break;
            }

@@ -469,6 +479,15 @@ const common_chat_msg message_assist_call {
     /* .tool_name = */ "",
     /* .tool_call_id = */ "",
 };
+const common_chat_msg message_assist_call_content {
+    "assistant",
+    "Hello, world!\nWhat's up?",
+    /* .content_parts = */ {},
+    tool_calls,
+    /* .reasoning_content = */ "",
+    /* .tool_name = */ "",
+    /* .tool_call_id = */ "",
+};
 const common_chat_msg message_assist_thoughts_no_content {
     "assistant",
     "",

@@ -722,8 +741,11 @@ static void test_template_output_parsers() {
        auto tmpls = read_templates("models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja");
        std::vector<std::string> end_tokens{ "<|END_OF_TURN_TOKEN|>" };

-        assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
-        assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+        for (const auto & inputs : { inputs_no_tools, inputs_tools }) {
+            auto params = common_chat_templates_apply(tmpls.get(), inputs);
+            assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, params.format);
+            assert_equals(false, params.thinking_forced_open);
+        }

        assert_msg_equals(message_assist,
                          common_chat_parse(

@@ -877,11 +899,25 @@ static void test_template_output_parsers() {
            tmpls.get(), end_tokens, message_assist_call_id, tools,
            "[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]");
    }
+    {
+        auto tmpls = read_templates("models/templates/Qwen-QwQ-32B.jinja");
+        std::vector<std::string> end_tokens{ "<|im_end|>" };
+
+        for (const auto & inputs : { inputs_no_tools, inputs_tools }) {
+            auto params = common_chat_templates_apply(tmpls.get(), inputs);
+            assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, params.format);
+            assert_equals(true, params.thinking_forced_open);
+        }
+    }
    {
        auto tmpls = read_templates("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja");
        std::vector<std::string> end_tokens{ "<|im_end|>" };

-        assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+        for (const auto & inputs : { inputs_no_tools, inputs_tools }) {
+            auto params = common_chat_templates_apply(tmpls.get(), inputs);
+            assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, params.format);
+            assert_equals(false, params.thinking_forced_open);
+        }
        assert_equals(
            COMMON_CHAT_FORMAT_HERMES_2_PRO,
            common_chat_templates_apply(

@@ -937,6 +973,13 @@ static void test_template_output_parsers() {
                "</tool_call>",
                /* is_partial= */ false,
                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(message_assist_call_content,
+            common_chat_parse(
+                "Hello, world!\nWhat's up?<tool_call>\n"
+                "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+                "</tool_call>",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
        assert_msg_equals(
            message_assist_call,
            common_chat_parse(

@@ -1066,6 +1109,27 @@ static void test_template_output_parsers() {
                /* is_partial= */ false,
                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));

+        assert_msg_equals(
+            {
+                /* .role = */ "assistant",
+                "This is not a tool call:\n"
+                "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}",
+                /* .content_parts = */ {},
+                /* .tool_calls = */ {},
+                /* .reasoning_content = */ "",
+                /* .tool_name = */ "",
+                /* .tool_call_id = */ "",
+            },
+            common_chat_parse(
+                "This is not a tool call:\n"
+                "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(message_assist,
+            common_chat_parse(
+                "Hello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
        assert_msg_equals(message_assist_thoughts_unparsed_deepseek,
            common_chat_parse(
                "<think>I'm thinking</think>Hello, world!\nWhat's up?",

@@ -1162,7 +1226,7 @@ static void test_template_output_parsers() {
                is_partial,
                {COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1}));
    }
-
+
    test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
    test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
        "<function=special_function>{\"arg1\": 1}</function>");

@@ -1243,7 +1307,11 @@ static void test_template_output_parsers() {
        auto tmpls = read_templates("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja");
        std::vector<std::string> end_tokens{ "<|end▁of▁sentence|>" };

-        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+        for (const auto & inputs : { inputs_no_tools, inputs_tools }) {
+            auto params = common_chat_templates_apply(tmpls.get(), inputs);
+            assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, params.format);
+            assert_equals(true, params.thinking_forced_open);
+        }

        test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
        test_templates(tmpls.get(), end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
