
Commit eaeed7d
Author: Olivier Chafik

fix trigger of thinking models (must happen after thoughts are closed)

Parent: 6ed8a8f

File tree: 10 files changed (+372, -192 lines)

common/chat-parser.cpp

Lines changed: 8 additions & 5 deletions
@@ -12,8 +12,8 @@

 using json = nlohmann::ordered_json;

-common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, bool extract_reasoning)
-    : input_(input), is_partial_(is_partial), extract_reasoning_(extract_reasoning)
+common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_reasoning_syntax & reasoning_syntax)
+    : input_(input), is_partial_(is_partial), reasoning_syntax_(reasoning_syntax)
 {
     result_.role = "assistant";

@@ -129,14 +129,17 @@ void common_chat_msg_parser::consume_literal(const std::string & literal) {
 }

 void common_chat_msg_parser::try_consume_think_tags(const common_regex & start_think_regex, const common_regex & end_think_regex) {
-    if (extract_reasoning_) {
-        if (try_consume_regex(start_think_regex)) {
+    if (reasoning_syntax_.format != COMMON_REASONING_FORMAT_NONE) {
+        if (reasoning_syntax_.thinking_forced_open || try_consume_regex(start_think_regex)) {
             if (auto res = try_find_regex(end_think_regex)) {
                 result_.reasoning_content = res->prelude;
                 consume_spaces();
             } else {
                 result_.reasoning_content = consume_rest();
-                incomplete("Failed to find end of reasoning tag " + end_think_regex.str());
+                if (!reasoning_syntax_.thinking_forced_open) {
+                    incomplete("Failed to find end of reasoning tag " + end_think_regex.str());
+                }
+                return;
             }
         } else if (auto res = try_find_regex(end_think_regex)) {
             result_.reasoning_content = res->prelude;
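The behavioral change above is easiest to see in isolation. Below is a minimal, standalone C++ sketch (standard library only, not the repo's common_chat_msg_parser API) of the same idea: when the chat template has already force-opened a `<think>` block, everything before a closing `</think>` counts as reasoning, and a missing closing tag is tolerated instead of being flagged as incomplete. The helper name parse_thoughts and its struct are hypothetical.

```cpp
// Standalone sketch of the forced-open reasoning behavior (not the llama.cpp API).
#include <iostream>
#include <string>

struct parsed_msg {
    std::string reasoning;
    std::string content;
};

static parsed_msg parse_thoughts(const std::string & input, bool thinking_forced_open) {
    parsed_msg out;
    std::string rest = input;
    bool open = thinking_forced_open;
    if (!open && rest.rfind("<think>", 0) == 0) { // explicit opening tag
        rest = rest.substr(7);
        open = true;
    }
    if (open) {
        auto end = rest.find("</think>");
        if (end == std::string::npos) {
            // forced-open thoughts with no closing tag yet: everything so far is reasoning
            out.reasoning = rest;
            return out;
        }
        out.reasoning = rest.substr(0, end);
        rest = rest.substr(end + 8);
    }
    out.content = rest;
    return out;
}

int main() {
    auto a = parse_thoughts("Let me think...", /* thinking_forced_open= */ true);
    std::cout << "reasoning: " << a.reasoning << " | content: " << a.content << "\n";
    auto b = parse_thoughts("plan</think>The answer is 42.", true);
    std::cout << "reasoning: " << b.reasoning << " | content: " << b.content << "\n";
}
```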

common/chat-parser.h

Lines changed: 5 additions & 8 deletions
@@ -8,8 +8,6 @@
 #include <string>
 #include <vector>

-using common_string_ranges = std::vector<common_string_range>;
-
 class common_chat_msg_partial_exception : public std::runtime_error {
   public:
     common_chat_msg_partial_exception(const std::string & message) : std::runtime_error(message) {}
@@ -18,18 +16,17 @@ class common_chat_msg_parser
 class common_chat_msg_parser {
     std::string input_;
     bool is_partial_;
-    bool extract_reasoning_;
+    common_chat_reasoning_syntax reasoning_syntax_;
+
     size_t pos_ = 0;
     common_chat_msg result_;
     std::string healing_marker_;

   public:
-    common_chat_msg_parser(const std::string & input, bool is_partial, bool extract_reasoning);
-
+    common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_reasoning_syntax & reasoning_syntax);
     const std::string & input() const { return input_; }
     const std::string & healing_marker() const { return healing_marker_; }
     const bool & is_partial() const { return is_partial_; }
-    const bool & extract_reasoning() const { return extract_reasoning_; }
     const common_chat_msg & result() const { return result_; }

     void move_to(size_t pos) {
@@ -70,13 +67,13 @@ class common_chat_msg_parser

     struct find_regex_result {
         std::string prelude;
-        common_string_ranges groups;
+        std::vector<common_string_range> groups;
     };

     std::optional<find_regex_result> try_find_regex(const common_regex & regex);

     struct consume_regex_result {
-        common_string_ranges groups;
+        std::vector<common_string_range> groups;
     };
     consume_regex_result consume_regex(const common_regex & regex);

common/chat.cpp

Lines changed: 142 additions & 71 deletions
Large diffs are not rendered by default.

common/chat.h

Lines changed: 26 additions & 20 deletions
@@ -37,6 +37,8 @@ struct common_chat_msg {
     std::string tool_name;
     std::string tool_call_id;

+    template <class T> T to_json_oaicompat() const;
+
     bool empty() const {
         return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty();
     }
@@ -54,6 +56,21 @@
     }
 };

+struct common_chat_msg_diff {
+    // std::string reasoning_content_delta;
+    std::string content_delta;
+    size_t tool_call_index = std::string::npos;
+    common_chat_tool_call tool_call_delta;
+
+    static std::vector<common_chat_msg_diff> compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg);
+
+    bool operator==(const common_chat_msg_diff & other) const {
+        return content_delta == other.content_delta
+            && tool_call_index == other.tool_call_index
+            && tool_call_delta == other.tool_call_delta;
+    }
+};
+
 struct common_chat_tool {
     std::string name;
     std::string description;
@@ -73,14 +90,11 @@ enum common_chat_format {
     COMMON_CHAT_FORMAT_LLAMA_3_X,
     COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
     COMMON_CHAT_FORMAT_DEEPSEEK_R1,
-    COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING,
     COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
     COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
     COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
     COMMON_CHAT_FORMAT_HERMES_2_PRO,
-    COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING,
     COMMON_CHAT_FORMAT_COMMAND_R7B,
-    COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING,

     COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
 };
@@ -95,19 +109,26 @@ struct common_chat_templates_inputs {
     std::vector<common_chat_tool> tools;
     common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
     bool parallel_tool_calls = false;
-    bool extract_reasoning = true;
+    common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
 };

 struct common_chat_params {
     common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
     std::string prompt;
     std::string grammar;
     bool grammar_lazy = false;
+    bool thinking_forced_open = false;
     std::vector<common_grammar_trigger> grammar_triggers;
     std::vector<std::string> preserved_tokens;
     std::vector<std::string> additional_stops;
 };

+struct common_chat_reasoning_syntax {
+    common_reasoning_format format = COMMON_REASONING_FORMAT_NONE;
+    bool inlined_in_content = false;
+    bool thinking_forced_open = false;
+};
+
 // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
 bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);

@@ -145,7 +166,7 @@ std::string common_chat_format_example(
     bool use_jinja);

 std::string common_chat_format_name(common_chat_format format);
-common_chat_msg common_chat_parse(const std::string & input, common_chat_format format, bool is_partial = false);
+common_chat_msg common_chat_parse(const std::string & input, common_chat_format format, bool is_partial = false, const common_chat_reasoning_syntax & reasoning_syntax = {});

 common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);

@@ -158,18 +179,3 @@ template <class T> T common_chat_msgs_to_json_oaicompat(const std::vector<common
 // T can be std::string containing JSON or nlohmann::ordered_json
 template <class T> std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const T & tools);
 template <class T> T common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);
-
-struct common_chat_msg_diff {
-    // std::string reasoning_content_delta;
-    std::string content_delta;
-    size_t tool_call_index = std::string::npos;
-    common_chat_tool_call tool_call_delta;
-
-    static std::vector<common_chat_msg_diff> compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg);
-
-    bool operator==(const common_chat_msg_diff & other) const {
-        return content_delta == other.content_delta
-            && tool_call_index == other.tool_call_index
-            && tool_call_delta == other.tool_call_delta;
-    }
-};
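A hedged usage sketch of the new common_chat_parse signature declared above. It assumes it is compiled inside the llama.cpp tree at this commit (so common/chat.h is on the include path); the wrapper function name is hypothetical, and the chosen format and reasoning values are illustrative only (COMMON_REASONING_FORMAT_DEEPSEEK is assumed to be the non-NONE enum value).

```cpp
#include "chat.h"

// Illustrative only: parse a partially streamed DeepSeek-R1 style message whose
// chat template has already force-opened the <think> block.
static common_chat_msg parse_streamed_chunk(const std::string & text_so_far) {
    common_chat_reasoning_syntax reasoning_syntax;
    reasoning_syntax.format               = COMMON_REASONING_FORMAT_DEEPSEEK; // assumed enum value
    reasoning_syntax.thinking_forced_open = true;

    return common_chat_parse(text_so_far, COMMON_CHAT_FORMAT_DEEPSEEK_R1,
                             /* is_partial */ true,
                             reasoning_syntax);
}
```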

common/common.h

Lines changed: 1 addition & 1 deletion
@@ -114,7 +114,7 @@ enum common_grammar_trigger_type {
     COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN,
     COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
     COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
-    COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
+    COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
 };

 struct common_grammar_trigger {
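For illustration, a small sketch of what the renamed trigger type expresses: a PATTERN_FULL trigger carries a regex that must describe the entire generated output, rather than a fragment that the sampler anchors itself. The helper name and the pattern string below are assumptions, not code from this commit; the type and value fields match the usage in common/sampling.cpp further down.

```cpp
#include "common.h"

// Illustrative: a whole-output trigger, e.g. one that only fires once a
// forced-open reasoning block has been closed before a tool call starts.
static common_grammar_trigger make_full_pattern_trigger() {
    common_grammar_trigger trigger;
    trigger.type  = COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL;
    trigger.value = "[\\s\\S]*?(</think>\\s*)?(<tool_call>)[\\s\\S]*"; // assumed example pattern
    return trigger;
}
```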

common/regex-partial.cpp

Lines changed: 16 additions & 13 deletions
@@ -22,31 +22,34 @@ common_regex_match common_regex::search(const std::string & input, size_t pos, b
             common_regex_match res;
             res.type = COMMON_REGEX_MATCH_TYPE_FULL;
             for (size_t i = 0; i < match.size(); ++i) {
-                common_string_range group;
-                group.begin = pos + match.position(i);
-                group.end = group.begin + match.length(i);
-                res.groups.push_back(group);
+                auto begin = pos + match.position(i);
+                res.groups.emplace_back(begin, begin + match.length(i));
             }
             return res;
         }
     }
     std::match_results<std::string::const_reverse_iterator> srmatch;
     if (std::regex_match(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial)) {
         auto group = srmatch[1].str();
-        auto it = srmatch[1].second.base();
-        // auto position = static_cast<size_t>(std::distance(input.begin(), it));
-        if ((!as_match && !at_start_) || it == input.begin()) {
-            common_regex_match res;
-            res.type = COMMON_REGEX_MATCH_TYPE_PARTIAL;
-            //res.groups.push_back({input.substr(position), position, input.size()});
-            res.groups.push_back({pos + std::distance(input.begin(), it), input.size()});
-            return res;
+        if (group.length() != 0) {
+            auto it = srmatch[1].second.base();
+            // auto position = static_cast<size_t>(std::distance(input.begin(), it));
+            if ((!as_match && !at_start_) || it == input.begin()) {
+                common_regex_match res;
+                res.type = COMMON_REGEX_MATCH_TYPE_PARTIAL;
+                auto begin = std::distance(input.begin(), it);
+                GGML_ASSERT(begin >= 0);
+                auto end = input.size();//begin + group.length();
+                GGML_ASSERT(static_cast<size_t>(begin) <= end);
+                res.groups.push_back({static_cast<size_t>(begin), end});
+                return res;
+            }
         }
     }
     return {};
 }

-/*
+/*xz
  Transforms a regex pattern to a partial match pattern that operates on a reversed input string to find partial final matches of the original pattern.

  Ideally we'd like to use boost::match_partial (https://beta.boost.org/doc/libs/1_59_0/libs/regex/doc/html/boost_regex/partial_matches.html)
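The `group.length() != 0` guard above exists so that an empty trailing group is not reported as a partial match. Below is a standard-library-only sketch of the underlying idea; this is not common_regex itself, and find_partial_suffix is a hypothetical helper.

```cpp
// Standalone illustration: when streaming, a suffix of the input may be an
// incomplete start of a token such as "</think>", but an empty suffix should
// never be reported as a partial match.
#include <algorithm>
#include <iostream>
#include <string>

// Returns the start offset of the longest non-empty suffix of `input` that is a
// proper prefix of `token`, or std::string::npos if there is none.
static size_t find_partial_suffix(const std::string & input, const std::string & token) {
    for (size_t len = std::min(input.size(), token.size() - 1); len > 0; --len) {
        if (input.compare(input.size() - len, len, token, 0, len) == 0) {
            return input.size() - len;
        }
    }
    return std::string::npos; // the empty suffix is not a partial match
}

int main() {
    std::cout << find_partial_suffix("some reasoning</thi", "</think>") << "\n";                     // 14
    std::cout << (find_partial_suffix("some reasoning", "</think>") == std::string::npos) << "\n";   // 1
}
```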

common/regex-partial.h

Lines changed: 6 additions & 0 deletions
@@ -2,6 +2,7 @@

 #include <regex>
 #include <string>
+#include "ggml.h"

 enum common_regex_match_type {
     COMMON_REGEX_MATCH_TYPE_NONE,
@@ -12,6 +13,11 @@ enum common_regex_match_type {
 struct common_string_range {
     size_t begin;
     size_t end;
+    common_string_range(size_t begin, size_t end) : begin(begin), end(end) {
+        GGML_ASSERT(begin <= end);
+    }
+    // prevent default ctor
+    common_string_range() = delete;
     bool empty() const {
         return begin == end;
     }
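A minimal standalone mirror of the checked range added above, using assert instead of GGML_ASSERT, just to make the new invariant concrete; checked_range is a hypothetical stand-in, not the real common_string_range.

```cpp
#include <cassert>
#include <cstddef>

struct checked_range {
    size_t begin;
    size_t end;
    // A range must satisfy begin <= end; violating it aborts at construction time.
    checked_range(size_t begin, size_t end) : begin(begin), end(end) {
        assert(begin <= end);
    }
    // No default constructor: an uninitialized range cannot be created by accident.
    checked_range() = delete;
    bool empty() const { return begin == end; }
};

int main() {
    checked_range ok{3, 10};     // fine
    // checked_range bad{10, 3}; // would trip the assert, like GGML_ASSERT in the real struct
    return ok.empty() ? 1 : 0;
}
```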

common/sampling.cpp

Lines changed: 7 additions & 8 deletions
@@ -160,7 +160,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
             GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
 #endif // LLAMA_USE_LLGUIDANCE
         } else {
-            std::vector<std::string> patterns_at_start;
+            std::vector<std::string> trigger_patterns;
             std::vector<std::string> patterns_anywhere;
             std::vector<llama_token> trigger_tokens;
             for (const auto & trigger : params.grammar_triggers) {
@@ -172,10 +172,13 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                         break;
                     }
                     case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
-                    case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
                     {
-                        const auto & pattern = trigger.value;
-                        (trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START ? patterns_at_start : patterns_anywhere).push_back(pattern);
+                        patterns_anywhere.push_back(trigger.value);
+                        break;
+                    }
+                    case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
+                    {
+                        trigger_patterns.push_back(trigger.value);
                         break;
                     }
                     case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
@@ -189,10 +192,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                 }
             }

-            std::vector<std::string> trigger_patterns;
-            if (!patterns_at_start.empty()) {
-                trigger_patterns.push_back("^(" + string_join(patterns_at_start, "|") + ")[\\s\\S]*");
-            }
             if (!patterns_anywhere.empty()) {
                 trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
             }
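To make the control flow above concrete, here is a standard-library-only sketch of how the two trigger kinds end up in trigger_patterns: PATTERN values are OR-ed together and wrapped so they can fire anywhere in the output, while PATTERN_FULL values are taken as complete whole-output patterns, which is what lets a thinking model's trigger wait until the thoughts are closed (per the commit message). The local string_join and the example pattern strings are stand-ins/assumptions.

```cpp
#include <iostream>
#include <string>
#include <vector>

// Local stand-in for the common string_join helper.
static std::string string_join(const std::vector<std::string> & parts, const std::string & sep) {
    std::string out;
    for (size_t i = 0; i < parts.size(); ++i) {
        if (i > 0) out += sep;
        out += parts[i];
    }
    return out;
}

int main() {
    // PATTERN_FULL triggers are passed through unchanged (illustrative value).
    std::vector<std::string> trigger_patterns = {
        "[\\s\\S]*?</think>\\s*(<tool_call>)[\\s\\S]*",
    };
    // PATTERN triggers get wrapped so they may match anywhere in the output.
    std::vector<std::string> patterns_anywhere = {"<tool_call>", "\\{\"name\":"};

    if (!patterns_anywhere.empty()) {
        trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
    }
    for (const auto & p : trigger_patterns) {
        std::cout << p << "\n";
    }
}
```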

docs/function-calling.md

Lines changed: 50 additions & 24 deletions
@@ -329,32 +329,58 @@ Test in CLI (or with any library / software that can use OpenAI-compatible API b

 ```bash
 curl http://localhost:8080/v1/chat/completions -d '{
-    "model": "gpt-3.5-turbo",
-    "tools": [
-    {
-        "type":"function",
-        "function":{
-            "name":"python",
-            "description":"Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
-            "parameters":{
-                "type":"object",
-                "properties":{
-                    "code":{
-                        "type":"string",
-                        "description":"The code to run in the ipython interpreter."
+  "model": "gpt-3.5-turbo",
+  "tools": [
+    {
+      "type":"function",
+      "function":{
+        "name":"python",
+        "description":"Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
+        "parameters":{
+          "type":"object",
+          "properties":{
+            "code":{
+              "type":"string",
+              "description":"The code to run in the ipython interpreter."
+            }
+          },
+          "required":["code"]
         }
-        },
-        "required":["code"]
       }
-    }
-    }
-  ],
-  "messages": [
-    {
-      "role": "user",
-      "content": "Print a hello world message with python."
-    }
-  ]
+    }
+  ],
+  "messages": [
+    {
+      "role": "user",
+      "content": "Print a hello world message with python."
+    }
+  ]
+}'
+
+
+curl http://localhost:8080/v1/chat/completions -d '{
+  "model": "gpt-3.5-turbo",
+  "messages": [
+    {"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things."},
+    {"role": "user", "content": "What is the weather in Istanbul?"}
+  ],
+  "tools": [{
+    "type":"function",
+    "function":{
+      "name":"get_current_weather",
+      "description":"Get the current weather in a given location",
+      "parameters":{
+        "type":"object",
+        "properties":{
+          "location":{
+            "type":"string",
+            "description":"The city and country/state, e.g. `San Francisco, CA`, or `Paris, France`"
+          }
+        },
+        "required":["location"]
+      }
+    }
+  }]
 }'
 ```
