Skip to content

Commit e1f526c

Browse files
fix: track supports_enable_thinking in chat templates and propagate it through common_chat_params for consistent <think> handling
- Added a supports_enable_thinking field to common_chat_params, populated it during template rendering, and reused it when deciding whether the generic <think> fallback should run.
- Updated common_chat_templates_support_enable_thinking to consult the tracked capability instead of rendering the template twice.
- Expanded the chat template tests to assert the flag and the guarded <think> fallback behaviour, covering templates that ignore enable_thinking as well as templates that conditionally open <think> blocks.
1 parent 0869085 commit e1f526c

File tree

3 files changed

+83
-17
lines changed

3 files changed

+83
-17
lines changed

common/chat.cpp

Lines changed: 43 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ struct templates_params {
149149
bool add_bos;
150150
bool add_eos;
151151
bool is_inference = true;
152+
bool supports_enable_thinking = false;
152153
};
153154

154155
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
@@ -171,10 +172,8 @@ bool common_chat_templates_support_enable_thinking(const common_chat_templates *
171172
msg.content = "test";
172173
dummy_inputs.messages = {msg};
173174
dummy_inputs.enable_thinking = false;
174-
const auto rendered_no_thinking = common_chat_templates_apply(chat_templates, dummy_inputs);
175-
dummy_inputs.enable_thinking = true;
176-
const auto rendered_with_thinking = common_chat_templates_apply(chat_templates, dummy_inputs);
177-
return rendered_no_thinking.prompt != rendered_with_thinking.prompt;
175+
const auto rendered = common_chat_templates_apply(chat_templates, dummy_inputs);
176+
return rendered.supports_enable_thinking;
178177
}
179178

180179
template <>
@@ -827,6 +826,7 @@ static std::string apply(
827826

828827
static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct templates_params & inputs) {
829828
common_chat_params data;
829+
data.supports_enable_thinking = inputs.supports_enable_thinking;
830830

831831
auto tool_call_schemas = json::array();
832832
foreach_function(inputs.tools, [&](const json & tool) {
@@ -944,6 +944,7 @@ static void common_chat_parse_generic(common_chat_msg_parser & builder) {
944944

945945
static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct templates_params & inputs) {
946946
common_chat_params data;
947+
data.supports_enable_thinking = inputs.supports_enable_thinking;
947948
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
948949
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
949950
auto schemas = json::array();
@@ -989,6 +990,7 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
989990

990991
static common_chat_params common_chat_params_init_magistral(const common_chat_template & tmpl, const struct templates_params & inputs) {
991992
common_chat_params data;
993+
data.supports_enable_thinking = inputs.supports_enable_thinking;
992994
data.prompt = apply(tmpl, inputs);
993995
data.format = COMMON_CHAT_FORMAT_MAGISTRAL;
994996
data.preserved_tokens = {
@@ -1069,6 +1071,7 @@ static void common_chat_parse_magistral(common_chat_msg_parser & builder) {
10691071

10701072
static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct templates_params & inputs) {
10711073
common_chat_params data;
1074+
data.supports_enable_thinking = inputs.supports_enable_thinking;
10721075

10731076
auto adjusted_messages = json::array();
10741077
for (const auto & msg : inputs.messages) {
@@ -1202,6 +1205,7 @@ static void expect_tool_parameters(const std::string & name, const json & parame
12021205
static common_chat_params common_chat_params_init_llama_3_x(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
12031206
auto builtin_tools = json::array();
12041207
common_chat_params data;
1208+
data.supports_enable_thinking = inputs.supports_enable_thinking;
12051209
if (!inputs.tools.is_null()) {
12061210
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
12071211
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
@@ -1281,6 +1285,7 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te
12811285

12821286
static common_chat_params common_chat_params_init_nemotron_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
12831287
common_chat_params data;
1288+
data.supports_enable_thinking = inputs.supports_enable_thinking;
12841289

12851290
// Generate the prompt using the apply() function with the template
12861291
data.prompt = apply(tmpl, inputs);
@@ -1342,6 +1347,7 @@ static common_chat_params common_chat_params_init_nemotron_v2(const common_chat_
13421347

13431348
static common_chat_params common_chat_params_init_apertus(const common_chat_template & tmpl, const struct templates_params & inputs) {
13441349
common_chat_params data;
1350+
data.supports_enable_thinking = inputs.supports_enable_thinking;
13451351

13461352
// Generate the prompt using the apply() function with the template
13471353
data.prompt = apply(tmpl, inputs);
@@ -1466,6 +1472,7 @@ static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool w
14661472

14671473
static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) {
14681474
common_chat_params data;
1475+
data.supports_enable_thinking = inputs.supports_enable_thinking;
14691476
auto prompt = apply(tmpl, inputs);
14701477

14711478
// Hacks to fix the official (broken) prompt.
@@ -1540,6 +1547,7 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
15401547

15411548
static common_chat_params common_chat_params_init_deepseek_v3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
15421549
common_chat_params data;
1550+
data.supports_enable_thinking = inputs.supports_enable_thinking;
15431551

15441552
// Pass thinking context for DeepSeek V3.1 template
15451553
json additional_context = {
@@ -1685,6 +1693,7 @@ static void common_chat_parse_deepseek_v3_1(common_chat_msg_parser & builder) {
16851693

16861694
static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, const struct templates_params & inputs) {
16871695
common_chat_params data;
1696+
data.supports_enable_thinking = inputs.supports_enable_thinking;
16881697
auto prompt = apply(tmpl, inputs);
16891698

16901699
// Check if we need to replace the return token with end token during
@@ -1904,6 +1913,7 @@ static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
19041913
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
19051914
LOG_DBG("%s\n", __func__);
19061915
common_chat_params data;
1916+
data.supports_enable_thinking = inputs.supports_enable_thinking;
19071917
const std::optional<json> tools_override = json();
19081918
const std::optional<json> additional_context = json {
19091919
{"datetime", format_time(inputs.now, "%b %d %Y %H:%M:%S GMT")},
@@ -1962,6 +1972,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
19621972
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
19631973
// If the function is python, we also allow raw python code (if the line after `python\n` doesn't start w/ opening `{`), which the model seems to prefer for multiline code.
19641974
common_chat_params data;
1975+
data.supports_enable_thinking = inputs.supports_enable_thinking;
19651976
data.prompt = apply(tmpl, inputs);
19661977
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2;
19671978
if (inputs.tools.is_array() && !inputs.tools.empty()) {
@@ -2038,6 +2049,7 @@ static void common_chat_parse_functionary_v3_2(common_chat_msg_parser & builder)
20382049
static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
20392050
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
20402051
common_chat_params data;
2052+
data.supports_enable_thinking = inputs.supports_enable_thinking;
20412053

20422054
if (!inputs.tools.is_null()) {
20432055
std::string python_code_argument_name;
@@ -2121,6 +2133,7 @@ static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser
21212133

21222134
static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) {
21232135
common_chat_params data;
2136+
data.supports_enable_thinking = inputs.supports_enable_thinking;
21242137

21252138
json extra_context = json {
21262139
{"enable_thinking", inputs.enable_thinking},
@@ -2314,6 +2327,7 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
23142327

23152328
static common_chat_params common_chat_params_init_granite(const common_chat_template & tmpl, const struct templates_params & inputs) {
23162329
common_chat_params data;
2330+
data.supports_enable_thinking = inputs.supports_enable_thinking;
23172331

23182332
// Pass thinking context for Granite template
23192333
json additional_context = {
@@ -2588,6 +2602,7 @@ static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) {
25882602

25892603
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
25902604
common_chat_params data;
2605+
data.supports_enable_thinking = inputs.supports_enable_thinking;
25912606
data.prompt = apply(tmpl, inputs);
25922607
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
25932608
data.grammar_lazy = false;
@@ -2600,18 +2615,20 @@ static common_chat_params common_chat_params_init_without_tools(const common_cha
26002615
data.grammar = inputs.grammar;
26012616
}
26022617

2603-
static constexpr size_t think_tag_len = 7; // strlen("<think>")
2604-
size_t prompt_trimmed_size = data.prompt.size();
2605-
while (prompt_trimmed_size > 0 &&
2606-
std::isspace(static_cast<unsigned char>(data.prompt[prompt_trimmed_size - 1]))) {
2607-
--prompt_trimmed_size;
2608-
}
2609-
if (prompt_trimmed_size >= think_tag_len &&
2610-
data.prompt.compare(prompt_trimmed_size - think_tag_len, think_tag_len, "<think>") == 0) {
2611-
if (!inputs.enable_thinking) {
2612-
data.prompt += "</think>";
2613-
} else {
2614-
data.thinking_forced_open = true;
2618+
if (inputs.supports_enable_thinking) {
2619+
static constexpr size_t think_tag_len = 7; // strlen("<think>")
2620+
size_t prompt_trimmed_size = data.prompt.size();
2621+
while (prompt_trimmed_size > 0 &&
2622+
std::isspace(static_cast<unsigned char>(data.prompt[prompt_trimmed_size - 1]))) {
2623+
--prompt_trimmed_size;
2624+
}
2625+
if (prompt_trimmed_size >= think_tag_len &&
2626+
data.prompt.compare(prompt_trimmed_size - think_tag_len, think_tag_len, "<think>") == 0) {
2627+
if (!inputs.enable_thinking) {
2628+
data.prompt += "</think>";
2629+
} else {
2630+
data.thinking_forced_open = true;
2631+
}
26152632
}
26162633
}
26172634
return data;
@@ -2623,6 +2640,7 @@ static common_chat_params common_chat_params_init_seed_oss(
26232640
const common_chat_templates_inputs & inputs)
26242641
{
26252642
common_chat_params data;
2643+
data.supports_enable_thinking = params.supports_enable_thinking;
26262644
data.prompt = apply(tmpl, params);
26272645
data.format = COMMON_CHAT_FORMAT_SEED_OSS;
26282646
if (string_ends_with(data.prompt, "<seed:think>")) {
@@ -2696,6 +2714,15 @@ static common_chat_params common_chat_templates_apply_jinja(
26962714
params.extra_context[el.first] = json::parse(el.second);
26972715
}
26982716

2717+
{
2718+
auto params_with_thinking = params;
2719+
params_with_thinking.enable_thinking = true;
2720+
auto params_without_thinking = params;
2721+
params_without_thinking.enable_thinking = false;
2722+
params.supports_enable_thinking =
2723+
apply(tmpl, params_with_thinking) != apply(tmpl, params_without_thinking);
2724+
}
2725+
26992726
if (!inputs.json_schema.empty()) {
27002727
params.json_schema = json::parse(inputs.json_schema);
27012728
}

common/chat.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ struct common_chat_params {
144144
std::string grammar;
145145
bool grammar_lazy = false;
146146
bool thinking_forced_open = false;
147+
bool supports_enable_thinking = false;
147148
std::vector<common_grammar_trigger> grammar_triggers;
148149
std::vector<std::string> preserved_tokens;
149150
std::vector<std::string> additional_stops;

tests/test-chat.cpp

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1353,13 +1353,51 @@ static void test_template_output_parsers() {
13531353
auto params_no_thinking = common_chat_templates_apply(tmpls.get(), inputs_no_thinking);
13541354
assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, params_no_thinking.format);
13551355
assert_equals(false, params_no_thinking.thinking_forced_open);
1356-
assert_equals(true, string_ends_with(params_no_thinking.prompt, "</think>"));
1356+
assert_equals(false, params_no_thinking.supports_enable_thinking);
1357+
assert_equals(true, string_ends_with(string_strip(params_no_thinking.prompt), "<think>"));
1358+
1359+
auto inputs_with_thinking = inputs_base;
1360+
inputs_with_thinking.enable_thinking = true;
1361+
auto params_with_thinking = common_chat_templates_apply(tmpls.get(), inputs_with_thinking);
1362+
assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, params_with_thinking.format);
1363+
assert_equals(false, params_with_thinking.thinking_forced_open);
1364+
assert_equals(false, params_with_thinking.supports_enable_thinking);
1365+
assert_equals(true, string_ends_with(string_strip(params_with_thinking.prompt), "<think>"));
1366+
1367+
assert_equals(false, common_chat_templates_support_enable_thinking(tmpls.get()));
1368+
}
1369+
{
1370+
// Template that conditionally appends <think> when enable_thinking is true.
1371+
static const char * tmpl_str = R"(
1372+
{% for message in messages %}
1373+
<|{{ message.role }}|>
1374+
{{ message.content }}
1375+
{% endfor %}
1376+
{% if add_generation_prompt %}<|assistant|>
1377+
{% if enable_thinking %}<think>{% endif %}
1378+
{% endif %}
1379+
)";
1380+
1381+
auto tmpls = common_chat_templates_ptr(common_chat_templates_init(/* model= */ nullptr, tmpl_str));
1382+
1383+
common_chat_templates_inputs inputs_base;
1384+
inputs_base.messages = { message_user };
1385+
inputs_base.add_generation_prompt = true;
1386+
1387+
auto inputs_no_thinking = inputs_base;
1388+
inputs_no_thinking.enable_thinking = false;
1389+
auto params_no_thinking = common_chat_templates_apply(tmpls.get(), inputs_no_thinking);
1390+
assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, params_no_thinking.format);
1391+
assert_equals(false, params_no_thinking.thinking_forced_open);
1392+
assert_equals(true, params_no_thinking.supports_enable_thinking);
1393+
assert_equals(false, string_ends_with(string_strip(params_no_thinking.prompt), "<think>"));
13571394

13581395
auto inputs_with_thinking = inputs_base;
13591396
inputs_with_thinking.enable_thinking = true;
13601397
auto params_with_thinking = common_chat_templates_apply(tmpls.get(), inputs_with_thinking);
13611398
assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, params_with_thinking.format);
13621399
assert_equals(true, params_with_thinking.thinking_forced_open);
1400+
assert_equals(true, params_with_thinking.supports_enable_thinking);
13631401
assert_equals(true, string_ends_with(string_strip(params_with_thinking.prompt), "<think>"));
13641402

13651403
assert_equals(true, common_chat_templates_support_enable_thinking(tmpls.get()));

0 commit comments

Comments
 (0)