@@ -33,6 +33,29 @@ struct chat_template_caps {
3333 bool requires_typed_content = false ;
3434};
3535
36+ struct chat_template_inputs {
37+ nlohmann::ordered_json messages;
38+ nlohmann::ordered_json tools;
39+ bool add_generation_prompt = true ;
40+ nlohmann::ordered_json extra_context;
41+ std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
42+ };
43+
44+ struct chat_template_options {
45+ bool apply_polyfills = true ;
46+ bool use_bos_token = true ;
47+ bool use_eos_token = true ;
48+ bool define_strftime_now = true ;
49+
50+ bool polyfill_tools = true ;
51+ bool polyfill_tool_call_examples = true ;
52+ bool polyfill_tool_calls = true ;
53+ bool polyfill_tool_responses = true ;
54+ bool polyfill_system_role = true ;
55+ bool polyfill_object_arguments = true ;
56+ bool polyfill_typed_content = true ;
57+ };
58+
3659class chat_template {
3760
3861 private:
@@ -41,6 +64,7 @@ class chat_template {
4164 std::string bos_token_;
4265 std::string eos_token_;
4366 std::shared_ptr<minja::TemplateNode> template_root_;
67+ std::string tool_call_example_;
4468
4569 std::string try_raw_render (
4670 const nlohmann::ordered_json & messages,
@@ -49,7 +73,18 @@ class chat_template {
4973 const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
5074 {
5175 try {
52- auto prompt = apply (messages, tools, add_generation_prompt, extra_context, /* adjust_inputs= */ false );
76+ chat_template_inputs inputs;
77+ inputs.messages = messages;
78+ inputs.tools = tools;
79+ inputs.add_generation_prompt = add_generation_prompt;
80+ inputs.extra_context = extra_context;
81+ // Use fixed date for tests
82+ inputs.now = std::chrono::system_clock::from_time_t (0 );
83+
84+ chat_template_options opts;
85+ opts.apply_polyfills = false ;
86+
87+ auto prompt = apply (inputs, opts);
5388 // fprintf(stderr, "try_raw_render: %s\n", prompt.c_str());
5489 return prompt;
5590 } catch (const std::exception & e) {
@@ -176,35 +211,131 @@ class chat_template {
176211 caps_.supports_tool_responses = contains (out, " Some response!" );
177212 caps_.supports_tool_call_id = contains (out, " call_911_" );
178213 }
214+
215+ try {
216+ if (!caps_.supports_tools ) {
217+ const json user_msg {
218+ {" role" , " user" },
219+ {" content" , " Hey" },
220+ };
221+ const json args {
222+ {" arg1" , " some_value" },
223+ };
224+ const json tool_call_msg {
225+ {" role" , " assistant" },
226+ {" content" , nullptr },
227+ {" tool_calls" , json::array ({
228+ {
229+ // TODO: detect if requires numerical id or fixed length == 6 like Nemo
230+ {" id" , " call_1___" },
231+ {" type" , " function" },
232+ {" function" , {
233+ {" name" , " tool_name" },
234+ {" arguments" , (caps_.requires_object_arguments ? args : json (minja::Value (args).dump (-1 , /* to_json= */ true )))},
235+ }},
236+ },
237+ })},
238+ };
239+ std::string prefix, full;
240+ {
241+ chat_template_inputs inputs;
242+ inputs.messages = json::array ({user_msg});
243+ inputs.add_generation_prompt = true ;
244+ prefix = apply (inputs);
245+ }
246+ {
247+ chat_template_inputs inputs;
248+ inputs.messages = json::array ({user_msg, tool_call_msg});
249+ inputs.add_generation_prompt = false ;
250+ full = apply (inputs);
251+ }
252+
253+ if (full.find (prefix) != 0 ) {
254+ if (prefix.rfind (eos_token_) == prefix.size () - eos_token_.size ()) {
255+ prefix = prefix.substr (0 , prefix.size () - eos_token_.size ());
256+ }
257+ }
258+ if (full.find (prefix) != 0 ) {
259+ fprintf (stderr, " Failed to infer a tool call example (possible template bug)\n " );
260+ }
261+ tool_call_example_ = full.substr (prefix.size ());
262+ }
263+ } catch (const std::exception & e) {
264+ fprintf (stderr, " Failed to generate tool call example: %s\n " , e.what ());
265+ }
179266 }
180267
181268 const std::string & source () const { return source_; }
182269 const std::string & bos_token () const { return bos_token_; }
183270 const std::string & eos_token () const { return eos_token_; }
184271 const chat_template_caps & original_caps () const { return caps_; }
185272
273+ // Deprecated, please use the form with chat_template_inputs and chat_template_options
186274 std::string apply (
187275 const nlohmann::ordered_json & messages,
188276 const nlohmann::ordered_json & tools,
189277 bool add_generation_prompt,
190278 const nlohmann::ordered_json & extra_context = nlohmann::ordered_json(),
191- bool adjust_inputs = true) const
279+ bool apply_polyfills = true)
280+ {
281+ fprintf (stderr, " [%s] Deprecated!\n " , __func__);
282+ chat_template_inputs inputs;
283+ inputs.messages = messages;
284+ inputs.tools = tools;
285+ inputs.add_generation_prompt = add_generation_prompt;
286+ inputs.extra_context = extra_context;
287+ inputs.now = std::chrono::system_clock::now ();
288+
289+ chat_template_options opts;
290+ opts.apply_polyfills = apply_polyfills;
291+
292+ return apply (inputs, opts);
293+ }
294+
295+ std::string apply (
296+ const chat_template_inputs & inputs,
297+ const chat_template_options & opts = chat_template_options()) const
192298 {
193299 json actual_messages;
194300
195- auto needs_adjustments = adjust_inputs && (false
196- || !caps_.supports_system_role
197- || !caps_.supports_tools
198- || !caps_.supports_tool_responses
199- || !caps_.supports_tool_calls
200- || caps_.requires_object_arguments
201- || caps_.requires_typed_content
301+ auto has_tools = inputs.tools .is_array () && !inputs.tools .empty ();
302+ auto has_tool_calls = false ;
303+ auto has_tool_responses = false ;
304+ auto has_string_content = false ;
305+ for (const auto & message : inputs.messages ) {
306+ if (message.contains (" tool_calls" ) && !message[" tool_calls" ].is_null ()) {
307+ has_tool_calls = true ;
308+ }
309+ if (message.contains (" role" ) && message[" role" ] == " tool" ) {
310+ has_tool_responses = true ;
311+ }
312+ if (message.contains (" content" ) && message[" content" ].is_string ()) {
313+ has_string_content = true ;
314+ }
315+ }
316+
317+ auto polyfill_system_role = opts.polyfill_system_role && !caps_.supports_system_role ;
318+ auto polyfill_tools = opts.polyfill_tools && has_tools && !caps_.supports_tools ;
319+ auto polyfill_tool_call_example = polyfill_tools && opts.polyfill_tool_call_examples ;
320+ auto polyfill_tool_calls = opts.polyfill_tool_calls && has_tool_calls && !caps_.supports_tool_calls ;
321+ auto polyfill_tool_responses = opts.polyfill_tool_responses && has_tool_responses && !caps_.supports_tool_responses ;
322+ auto polyfill_object_arguments = opts.polyfill_object_arguments && has_tool_calls && caps_.requires_object_arguments ;
323+ auto polyfill_typed_content = opts.polyfill_typed_content && has_string_content && caps_.requires_typed_content ;
324+
325+ auto needs_polyfills = opts.apply_polyfills && (false
326+ || polyfill_system_role
327+ || polyfill_tools
328+ || polyfill_tool_calls
329+ || polyfill_tool_responses
330+ || polyfill_object_arguments
331+ || polyfill_typed_content
202332 );
203- if (needs_adjustments) {
333+
334+ if (needs_polyfills) {
204335 actual_messages = json::array ();
205336
206337 auto add_message = [&](const json & msg) {
207- if (caps_. requires_typed_content && msg.contains (" content" ) && !msg.at (" content" ).is_null () && msg.at (" content" ).is_string ()) {
338+ if (polyfill_typed_content && msg.contains (" content" ) && !msg.at (" content" ).is_null () && msg.at (" content" ).is_string ()) {
208339 actual_messages.push_back ({
209340 {" role" , msg.at (" role" )},
210341 {" content" , {{
@@ -227,17 +358,25 @@ class chat_template {
227358 pending_system.clear ();
228359 }
229360 };
230- auto needs_tools_in_system = !tools.is_null () && tools.size () > 0 && !caps_.supports_tools ;
231361
232- for (const auto & message_ : needs_tools_in_system ? add_system (messages, " Available tools: " + tools.dump (2 )) : messages) {
362+ json adjusted_messages;
363+ if (polyfill_tools) {
364+ adjusted_messages = add_system (inputs.messages ,
365+ " You can call any of the following tools to satisfy the user's requests: " + minja::Value (inputs.tools ).dump (2 , /* to_json= */ true ) +
366+ (!polyfill_tool_call_example || tool_call_example_.empty () ? " " : " \n\n Example tool call syntax:\n\n " + tool_call_example_));
367+ } else {
368+ adjusted_messages = inputs.messages ;
369+ }
370+
371+ for (const auto & message_ : adjusted_messages) {
233372 auto message = message_;
234373 if (!message.contains (" role" ) || !message.contains (" content" )) {
235374 throw std::runtime_error (" message must have 'role' and 'content' fields: " + message.dump ());
236375 }
237376 std::string role = message.at (" role" );
238377
239378 if (message.contains (" tool_calls" )) {
240- if (caps_. requires_object_arguments || !caps_. supports_tool_calls ) {
379+ if (polyfill_object_arguments || polyfill_tool_calls ) {
241380 for (auto & tool_call : message.at (" tool_calls" )) {
242381 if (tool_call[" type" ] == " function" ) {
243382 auto & function = tool_call.at (" function" );
@@ -252,7 +391,7 @@ class chat_template {
252391 }
253392 }
254393 }
255- if (!caps_. supports_tool_calls ) {
394+ if (polyfill_tool_calls ) {
256395 auto content = message.at (" content" );
257396 auto tool_calls = json::array ();
258397 for (const auto & tool_call : message.at (" tool_calls" )) {
@@ -279,7 +418,7 @@ class chat_template {
279418 message.erase (" tool_calls" );
280419 }
281420 }
282- if (!caps_. supports_tool_responses && role == " tool" ) {
421+ if (polyfill_tool_responses && role == " tool" ) {
283422 message[" role" ] = " user" ;
284423 auto obj = json {
285424 {" tool_response" , {
@@ -296,7 +435,7 @@ class chat_template {
296435 message.erase (" name" );
297436 }
298437
299- if (!message[" content" ].is_null () && !caps_. supports_system_role ) {
438+ if (!message[" content" ].is_null () && polyfill_system_role ) {
300439 std::string content = message.at (" content" );
301440 if (role == " system" ) {
302441 if (!pending_system.empty ()) pending_system += " \n " ;
@@ -315,28 +454,36 @@ class chat_template {
315454 }
316455 add_message (message);
317456 }
318- if (!caps_.supports_system_role ) {
319- flush_sys ();
320- }
457+ flush_sys ();
321458 } else {
322- actual_messages = messages;
459+ actual_messages = inputs. messages ;
323460 }
324461
325462 auto context = minja::Context::make (json ({
326463 {" messages" , actual_messages},
327- {" add_generation_prompt" , add_generation_prompt},
328- {" bos_token" , bos_token_},
329- {" eos_token" , eos_token_},
464+ {" add_generation_prompt" , inputs.add_generation_prompt },
330465 }));
331-
332- if (!tools.is_null ()) {
333- auto tools_val = minja::Value (tools);
334- context->set (" tools" , tools_val);
466+ context->set (" bos_token" , opts.use_bos_token ? bos_token_ : " " );
467+ context->set (" eos_token" , opts.use_eos_token ? eos_token_ : " " );
468+ if (opts.define_strftime_now ) {
469+ auto now = inputs.now ;
470+ context->set (" strftime_now" , Value::callable ([now](const std::shared_ptr<minja::Context> &, minja::ArgumentsValue & args) {
471+ args.expectArgs (" strftime_now" , {1 , 1 }, {0 , 0 });
472+ auto format = args.args [0 ].get <std::string>();
473+
474+ auto time = std::chrono::system_clock::to_time_t (now);
475+ auto local_time = *std::localtime (&time);
476+ std::ostringstream ss;
477+ ss << std::put_time (&local_time, format.c_str ());
478+ return ss.str ();
479+ }));
480+ }
481+ if (!inputs.tools .is_null ()) {
482+ context->set (" tools" , minja::Value (inputs.tools ));
335483 }
336- if (!extra_context.is_null ()) {
337- for (auto & kv : extra_context.items ()) {
338- minja::Value val (kv.value ());
339- context->set (kv.key (), val);
484+ if (!inputs.extra_context .is_null ()) {
485+ for (auto & kv : inputs.extra_context .items ()) {
486+ context->set (kv.key (), minja::Value (kv.value ()));
340487 }
341488 }
342489
@@ -353,7 +500,7 @@ class chat_template {
353500 std::string existing_system = messages_with_system.at (0 ).at (" content" );
354501 messages_with_system[0 ] = json {
355502 {" role" , " system" },
356- {" content" , existing_system + " \n " + system_prompt},
503+ {" content" , existing_system + " \n\n " + system_prompt},
357504 };
358505 } else {
359506 messages_with_system.insert (messages_with_system.begin (), json {
0 commit comments