Skip to content

Commit 7eb5202

Browse files
authored
Fix deepseek r1 tool call example polyfill (template newly adds trailing <think>) (#52)
* Fix deepseek r1 tool call example polyfill (their template newly adds trailing <think>) * test tool outputs for common templates * tests: align extra context in c++ w/ python + remove python tojson override
1 parent e259cda commit 7eb5202

File tree

8 files changed

+310
-75
lines changed

8 files changed

+310
-75
lines changed

include/minja/chat-template.hpp

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -254,10 +254,25 @@ class chat_template {
254254
(full[full.size() - 1] == '\n' && (eos_pos_last == full.size() - eos_token_.size() - 1))) {
255255
full = full.substr(0, eos_pos_last);
256256
}
257-
if (full.find(prefix) != 0) {
257+
size_t common_prefix_length = 0;
258+
for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) {
259+
if (prefix[i] != full[i]) {
260+
break;
261+
}
262+
if (prefix[i] == '<') {
263+
// DeepSeek R1's template (as of 20250209) adds a trailing <think> if add_generation_prompt,
264+
// but it removes thinking tags for past messages.
265+
// The prefix and full strings diverge at <think> vs. <|tool▁calls▁begin|>, we avoid consuming the leading <.
266+
continue;
267+
}
268+
common_prefix_length = i + 1;
269+
}
270+
auto example = full.substr(common_prefix_length);
271+
if (example.find("tool_name") == std::string::npos && example.find("some_value") == std::string::npos) {
258272
fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n");
273+
} else {
274+
tool_call_example_ = example;
259275
}
260-
tool_call_example_ = full.substr(prefix.size());
261276
}
262277
} catch (const std::exception & e) {
263278
fprintf(stderr, "Failed to generate tool call example: %s\n", e.what());

scripts/fetch_templates_and_goldens.py

Lines changed: 38 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,6 @@ def raise_exception(message: str):
4343
raise ValueError(message)
4444

4545

46-
def tojson(eval_ctx, value, indent=None):
47-
return json.dumps(value, indent=indent)
48-
4946
TEST_DATE = os.environ.get('TEST_DATE', '2024-07-26')
5047

5148

@@ -114,16 +111,22 @@ def try_raw_render(self, messages, *, tools=[], add_generation_prompt=False, ext
114111
# print(out, file=sys.stderr)
115112
return out
116113
except BaseException as e:
117-
# print(f"{template_file}: Error rendering template with messages {messages}: {e}", file=sys.stderr, flush=True)
114+
# print(f"Error rendering template with messages {messages}: {e}", file=sys.stderr, flush=True)
118115
return ""
119116

120-
def __init__(self, template, known_eos_tokens, env=None):
117+
def __init__(self, template, env=None, filters=None, global_functions=None):
121118
if not env:
122119
env = jinja2.Environment(
123120
trim_blocks=True,
124121
lstrip_blocks=True,
125122
extensions=[jinja2.ext.loopcontrols]
126123
)
124+
if filters:
125+
for name, func in filters.items():
126+
env.filters[name] = func
127+
if global_functions:
128+
for name, func in global_functions.items():
129+
env.globals[name] = func
127130
self.env = env
128131
self.template = env.from_string(template)
129132

@@ -243,15 +246,24 @@ def make_tool_call(tool_name, arguments):
243246
}
244247
prefix = self.try_raw_render([user_msg], add_generation_prompt=True)
245248
full = self.try_raw_render([user_msg, tool_call_msg], add_generation_prompt=False)
246-
if not full.startswith(prefix):
247-
for known_eos_token in known_eos_tokens:
248-
prefix = prefix.rstrip()
249-
if prefix.endswith(known_eos_token):
250-
prefix = prefix[:-len(known_eos_token)]
251-
break
252-
if not full.startswith(prefix):
249+
250+
common_prefix_length = 0
251+
for i in range(min(len(prefix), len(full))):
252+
if prefix[i] != full[i]:
253+
break
254+
if prefix[i] == '<':
255+
# DeepSeek R1's template (as of 20250209) adds a trailing <think> if add_generation_prompt,
256+
# but it removes thinking tags for past messages.
257+
# The prefix and full strings diverge at <think> vs. <|tool▁calls▁begin|>, we avoid consuming the leading <.
258+
continue
259+
common_prefix_length = i + 1
260+
261+
example = full[common_prefix_length:]
262+
if "tool_name" not in example and "some_value" not in example:
253263
print("Failed to infer a tool call example (possible template bug)", file=sys.stderr)
254-
self.tool_call_example = full[len(prefix):]
264+
else:
265+
self.tool_call_example = example
266+
255267
except Exception as e:
256268
print(f"Failed to generate tool call example: {e}", file=sys.stderr)
257269

@@ -321,7 +333,11 @@ def apply(self, context):
321333
message['content'] = [{"type": "text", "text": message['content']}]
322334

323335
try:
324-
return self.template.render(**context)
336+
out = self.template.render(**context)
337+
out = out.replace("\\u0027", "'")
338+
out = out.replace('&#34;', '"')
339+
out = out.replace('&#39;', "'")
340+
return out
325341
except Exception as e1:
326342
for message in context['messages']:
327343
if message.get("content") is None:
@@ -350,21 +366,14 @@ async def handle_chat_template(output_folder, model_id, variant, template_src, c
350366
async with aiofiles.open(template_file, 'w') as f:
351367
await f.write(template_src)
352368

353-
known_eos_tokens = [
354-
"<|END_OF_TURN_TOKEN|>",
355-
"<end_of_turn>",
356-
"</s>",
357-
"<|im_end|>",
358-
"<|eom_id|>",
359-
"<|eot_id|>",
360-
"<|end▁of▁sentence|>",
361-
]
362-
363-
template = chat_template(template_src, known_eos_tokens)
364-
template.env.filters['safe'] = lambda x: x
365-
template.env.filters['tojson'] = tojson
366-
template.env.globals['raise_exception'] = raise_exception
367-
template.env.globals['strftime_now'] = strftime_now
369+
template = chat_template(template_src,
370+
filters={
371+
'safe': lambda x: x,
372+
},
373+
global_functions={
374+
'raise_exception': raise_exception,
375+
'strftime_now': strftime_now,
376+
})
368377
caps = template.original_caps
369378

370379
if not context_files:

tests/CMakeLists.txt

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,8 @@ target_link_libraries(test-polyfills PRIVATE
3131
)
3232
if (NOT CMAKE_CROSSCOMPILING)
3333
gtest_discover_tests(test-syntax)
34-
endif()
35-
36-
if (NOT CMAKE_CROSSCOMPILING)
37-
gtest_discover_tests(test-syntax)
38-
gtest_discover_tests(test-polyfills)
34+
add_test(NAME test-polyfills COMMAND test-polyfills)
35+
set_tests_properties(test-polyfills PROPERTIES WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
3936
endif()
4037

4138
add_executable(test-capabilities test-capabilities.cpp)
@@ -82,6 +79,7 @@ set(MODEL_IDS
8279
MiniMaxAI/MiniMax-Text-01
8380
indischepartij/MiniCPM-3B-OpenHermes-2.5-v2
8481
mattshumer/Reflection-Llama-3.1-70B
82+
meetkai/functionary-medium-v3.1
8583
meetkai/functionary-medium-v3.2
8684
meta-llama/Llama-3.1-8B-Instruct # Gated
8785
meta-llama/Llama-3.2-3B-Instruct # Gated

tests/contexts/simple.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,6 @@
1111
],
1212
"add_generation_prompt": true,
1313
"bos_token": "<|startoftext|>",
14-
"eos_token": "<|endoftext|>"
14+
"eos_token": "<|endoftext|>",
15+
"tools_in_user_message": false
1516
}

tests/contexts/system.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,6 @@
1515
],
1616
"add_generation_prompt": true,
1717
"bos_token": "<|startoftext|>",
18-
"eos_token": "<|endoftext|>"
18+
"eos_token": "<|endoftext|>",
19+
"tools_in_user_message": false
1920
}

tests/contexts/tool_use.json

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@
8888
"add_generation_prompt": true,
8989
"bos_token": "<|startoftext|>",
9090
"eos_token": "<|endoftext|>",
91+
"tools_in_user_message": false,
9192
"builtin_tools": [
9293
"wolfram_alpha",
9394
"brave_search"
@@ -96,72 +97,72 @@
9697
"todays_date": "2024-09-03",
9798
"tools": [
9899
{
99-
"type": "function",
100100
"function": {
101-
"name": "ipython",
102101
"description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
102+
"name": "ipython",
103103
"parameters": {
104-
"type": "object",
105104
"properties": {
106105
"code": {
107-
"type": "string",
108-
"description": "The code to run in the ipython interpreter."
106+
"description": "The code to run in the ipython interpreter.",
107+
"type": "string"
109108
}
110109
},
111-
"required": ["code"]
110+
"required": ["code"],
111+
"type": "object"
112112
}
113-
}
113+
},
114+
"type": "function"
114115
},
115116
{
116-
"type": "function",
117117
"function": {
118-
"name": "brave_search",
119118
"description": "Executes a web search with Brave.",
119+
"name": "brave_search",
120120
"parameters": {
121-
"type": "object",
122121
"properties": {
123122
"query": {
124-
"type": "string",
125-
"description": "The query to search for."
123+
"description": "The query to search for.",
124+
"type": "string"
126125
}
127126
},
128-
"required": ["query"]
127+
"required": ["query"],
128+
"type": "object"
129129
}
130-
}
130+
},
131+
"type": "function"
131132
},
132133
{
133-
"type": "function",
134134
"function": {
135-
"name": "wolfram_alpha",
136135
"description": "Executes a query with Wolfram Alpha.",
136+
"name": "wolfram_alpha",
137137
"parameters": {
138-
"type": "object",
139138
"properties": {
140139
"query": {
141-
"type": "string",
142-
"description": "The query to execute."
140+
"description": "The query to execute.",
141+
"type": "string"
143142
}
144143
},
145-
"required": ["query"]
144+
"required": ["query"],
145+
"type": "object"
146146
}
147-
}
147+
},
148+
"type": "function"
148149
},
149150
{
150-
"type": "function",
151151
"function": {
152-
"name": "test",
153152
"description": "Runs a test.",
153+
"name": "test",
154154
"parameters": {
155-
"type": "object",
156155
"properties": {
157156
"condition": {
158-
"type": "boolean",
159-
"description": "The condition to test."
157+
"description": "The condition to test.",
158+
"type": "boolean"
160159
}
161160
},
162-
"required": ["condition"]
161+
"required": ["condition"],
162+
"type": "object"
163163
}
164-
}
164+
},
165+
"type": "function"
165166
}
166167
]
167168
}

0 commit comments

Comments
 (0)