Skip to content

Commit c3768f4

Browse files
committed
Add guards against stripped reasoning
1 parent e403844 commit c3768f4

File tree

1 file changed

+33
-10
lines changed

1 file changed

+33
-10
lines changed

tools/main/main.cpp

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -136,14 +136,23 @@ class chat_formatter {
136136
params_(params) {}
137137

138138
std::string operator()(const std::string & role, const std::string & content) {
139+
if (role == "user") {
140+
formatted_cumulative_.clear(); // Needed if template strips reasoning
141+
}
142+
139143
common_chat_msg new_msg;
144+
if (syntax_) {
145+
new_msg = common_chat_parse(content, false, *syntax_);
146+
} else {
147+
new_msg.content = content;
148+
}
140149
new_msg.role = role;
141-
new_msg.content = content;
150+
142151
chat_msgs_.push_back(new_msg);
143152

144153
common_chat_templates_inputs cinputs;
154+
cinputs.messages.assign(chat_msgs_.cbegin(), chat_msgs_.cend());
145155
cinputs.use_jinja = params_.use_jinja;
146-
cinputs.messages = chat_msgs_;
147156
cinputs.add_generation_prompt = (role == "user");
148157
cinputs.reasoning_format = params_.reasoning_format;
149158

@@ -154,16 +163,29 @@ class chat_formatter {
154163

155164
common_chat_params cparams = common_chat_templates_apply(chat_templates_.get(), cinputs);
156165

157-
if (!partial_formatter_ptr_ && params_.reasoning_format != COMMON_REASONING_FORMAT_NONE) {
158-
common_chat_syntax chat_syntax;
159-
chat_syntax.format = cparams.format;
160-
chat_syntax.reasoning_format = params_.reasoning_format;
161-
chat_syntax.thinking_forced_open = cparams.thinking_forced_open;
162-
chat_syntax.parse_tool_calls = false;
163-
partial_formatter_ptr_ = std::make_unique<partial_formatter>(chat_syntax);
166+
if (!syntax_) {
167+
syntax_.reset(new common_chat_syntax);
168+
syntax_->format = cparams.format;
169+
syntax_->reasoning_format = params_.reasoning_format;
170+
syntax_->thinking_forced_open = cparams.thinking_forced_open;
171+
syntax_->parse_tool_calls = false;
172+
}
173+
174+
bool use_partial_formatter = params_.reasoning_format != COMMON_REASONING_FORMAT_NONE;
175+
if (!partial_formatter_ptr_ && use_partial_formatter) {
176+
partial_formatter_ptr_ = std::make_unique<partial_formatter>(*syntax_);
177+
}
178+
179+
std::string formatted;
180+
if (formatted_cumulative_.size() > cparams.prompt.size()) {
181+
LOG_WRN("template cumulative size was reduced from \"%zu\" to \"%zu\" "
182+
"likely due to template's removal of message reasoning.\n",
183+
formatted_cumulative_.size(), cparams.prompt.size());
184+
185+
} else {
186+
formatted = cparams.prompt.substr(formatted_cumulative_.size());
164187
}
165188

166-
std::string formatted = cparams.prompt.substr(formatted_cumulative_.size());
167189
formatted_cumulative_ = cparams.prompt;
168190

169191
LOG_DBG("formatted: '%s'\n", formatted.c_str());
@@ -177,6 +199,7 @@ class chat_formatter {
177199
std::vector<common_chat_msg> & chat_msgs_;
178200
const common_chat_templates_ptr & chat_templates_;
179201
const common_params & params_;
202+
std::unique_ptr<common_chat_syntax> syntax_;
180203
std::unique_ptr<partial_formatter> partial_formatter_ptr_;
181204
std::string formatted_cumulative_;
182205
};

0 commit comments

Comments
 (0)