Skip to content

Commit e598e7a

Browse files
author
ochafik
committed
sync: minja (google/minja#52)
1 parent 95cddfd commit e598e7a

File tree

1 file changed

+21
-7
lines changed

1 file changed

+21
-7
lines changed

common/chat-template.hpp

Lines changed: 21 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -249,16 +249,30 @@ class chat_template {
249249
inputs.add_generation_prompt = false;
250250
full = apply(inputs);
251251
}
252-
253-
if (full.find(prefix) != 0) {
254-
if (prefix.rfind(eos_token_) == prefix.size() - eos_token_.size()) {
255-
prefix = prefix.substr(0, prefix.size() - eos_token_.size());
252+
auto eos_pos_last = full.rfind(eos_token_);
253+
if (eos_pos_last == prefix.size() - eos_token_.size() ||
254+
(full[full.size() - 1] == '\n' && (eos_pos_last == full.size() - eos_token_.size() - 1))) {
255+
full = full.substr(0, eos_pos_last);
256+
}
257+
size_t common_prefix_length = 0;
258+
for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) {
259+
if (prefix[i] != full[i]) {
260+
break;
256261
}
262+
if (prefix[i] == '<') {
263+
// DeepSeek R1's template (as of 20250209) adds a trailing <think> if add_generation_prompt,
264+
// but it removes thinking tags for past messages.
265+
// The prefix and full strings diverge at <think> vs. <|tool▁calls▁begin|>, so we avoid consuming the leading '<'.
266+
continue;
267+
}
268+
common_prefix_length = i + 1;
257269
}
258-
if (full.find(prefix) != 0) {
270+
auto example = full.substr(common_prefix_length);
271+
if (example.find("tool_name") == std::string::npos && example.find("some_value") == std::string::npos) {
259272
fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n");
273+
} else {
274+
tool_call_example_ = example;
260275
}
261-
tool_call_example_ = full.substr(prefix.size());
262276
}
263277
} catch (const std::exception & e) {
264278
fprintf(stderr, "Failed to generate tool call example: %s\n", e.what());
@@ -363,7 +377,7 @@ class chat_template {
363377
if (polyfill_tools) {
364378
adjusted_messages = add_system(inputs.messages,
365379
"You can call any of the following tools to satisfy the user's requests: " + minja::Value(inputs.tools).dump(2, /* to_json= */ true) +
366-
(!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_));
380+
(!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_ + "\n\n"));
367381
} else {
368382
adjusted_messages = inputs.messages;
369383
}

0 commit comments

Comments
 (0)