@@ -33,6 +33,29 @@ struct chat_template_caps {
3333 bool requires_typed_content = false ;
3434};
3535
36+ struct chat_template_inputs {
37+ nlohmann::ordered_json messages;
38+ nlohmann::ordered_json tools;
39+ bool add_generation_prompt = true ;
40+ nlohmann::ordered_json extra_context;
41+ std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
42+ };
43+
44+ struct chat_template_options {
45+ bool apply_polyfills = true ;
46+ bool use_bos_token = true ;
47+ bool use_eos_token = true ;
48+ bool define_strftime_now = true ;
49+
50+ bool polyfill_tools = true ;
51+ bool polyfill_tool_call_examples = true ;
52+ bool polyfill_tool_calls = true ;
53+ bool polyfill_tool_responses = true ;
54+ bool polyfill_system_role = true ;
55+ bool polyfill_object_arguments = true ;
56+ bool polyfill_typed_content = true ;
57+ };
58+
3659class chat_template {
3760
3861 private:
@@ -50,7 +73,18 @@ class chat_template {
5073 const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
5174 {
5275 try {
53- auto prompt = apply (messages, tools, add_generation_prompt, extra_context, /* adjust_inputs= */ false );
76+ chat_template_inputs inputs;
77+ inputs.messages = messages;
78+ inputs.tools = tools;
79+ inputs.add_generation_prompt = add_generation_prompt;
80+ inputs.extra_context = extra_context;
81+ // Use fixed date for tests
82+ inputs.now = std::chrono::system_clock::from_time_t (0 );
83+
84+ chat_template_options opts;
85+ opts.apply_polyfills = false ;
86+
87+ auto prompt = apply (inputs, opts);
5488 // fprintf(stderr, "try_raw_render: %s\n", prompt.c_str());
5589 return prompt;
5690 } catch (const std::exception & e) {
@@ -178,44 +212,56 @@ class chat_template {
178212 caps_.supports_tool_call_id = contains (out, " call_911_" );
179213 }
180214
181- if (!caps_.supports_tools ) {
182- const json user_msg {
183- {" role" , " user" },
184- {" content" , " Hey" },
185- };
186- const json tool_call_msg {
187- {" role" , " assistant" },
188- {" content" , nullptr },
189- {" tool_calls" , json::array ({
190- {
191- // TODO: detect if requires numerical id or fixed length == 6 like Nemo
192- {" id" , " call_1___" },
193- {" type" , " function" },
194- {" function" , {
195- {" name" , " tool_name" },
196- {" arguments" , (json {
197- {" arg1" , " some_value" },
198- }).dump ()},
199- }},
200- },
201- })},
202- };
203- const json tools;
204- auto prefix = apply (json::array ({user_msg}), tools, /* add_generation_prompt= */ true );
205- auto full = apply (json::array ({user_msg, tool_call_msg}), tools, /* add_generation_prompt= */ false );
206- if (full.find (prefix) != 0 && prefix.length () > 0 && prefix[prefix.length () - 1 ] == ' \n ' ) {
207- prefix = prefix.substr (0 , prefix.length () - 1 );
208- }
209- if (full.find (prefix) != 0 ) {
210- if (prefix.rfind (eos_token_) == prefix.size () - eos_token_.size ()) {
211- prefix = prefix.substr (0 , prefix.size () - eos_token_.size ());
212- } else {
213- throw std::runtime_error (" prefix not found at start of full: " + prefix + " vs " + full);
215+ try {
216+ if (!caps_.supports_tools ) {
217+ const json user_msg {
218+ {" role" , " user" },
219+ {" content" , " Hey" },
220+ };
221+ const json args {
222+ {" arg1" , " some_value" },
223+ };
224+ const json tool_call_msg {
225+ {" role" , " assistant" },
226+ {" content" , nullptr },
227+ {" tool_calls" , json::array ({
228+ {
229+ // TODO: detect if requires numerical id or fixed length == 6 like Nemo
230+ {" id" , " call_1___" },
231+ {" type" , " function" },
232+ {" function" , {
233+ {" name" , " tool_name" },
234+ {" arguments" , (caps_.requires_object_arguments ? args : json (minja::Value (args).dump (-1 , /* to_json= */ true )))},
235+ }},
236+ },
237+ })},
238+ };
239+ std::string prefix, full;
240+ {
241+ chat_template_inputs inputs;
242+ inputs.messages = json::array ({user_msg});
243+ inputs.add_generation_prompt = true ;
244+ prefix = apply (inputs);
245+ }
246+ {
247+ chat_template_inputs inputs;
248+ inputs.messages = json::array ({user_msg, tool_call_msg});
249+ inputs.add_generation_prompt = false ;
250+ full = apply (inputs);
214251 }
215- } else {
216252
253+ if (full.find (prefix) != 0 ) {
254+ if (prefix.rfind (eos_token_) == prefix.size () - eos_token_.size ()) {
255+ prefix = prefix.substr (0 , prefix.size () - eos_token_.size ());
256+ }
257+ }
258+ if (full.find (prefix) != 0 ) {
259+ fprintf (stderr, " Failed to infer a tool call example (possible template bug)\n " );
260+ }
261+ tool_call_example_ = full.substr (prefix.size ());
217262 }
218- tool_call_example_ = full.substr (prefix.size ());
263+ } catch (const std::exception & e) {
264+ fprintf (stderr, " Failed to generate tool call example: %s\n " , e.what ());
219265 }
220266 }
221267
@@ -225,27 +271,49 @@ class chat_template {
225271 const chat_template_caps & original_caps () const { return caps_; }
226272
227273 std::string apply (
228- const nlohmann::ordered_json & messages,
229- const nlohmann::ordered_json & tools,
230- bool add_generation_prompt,
231- const nlohmann::ordered_json & extra_context = nlohmann::ordered_json(),
232- bool adjust_inputs = true) const
274+ const chat_template_inputs & inputs,
275+ const chat_template_options & opts = chat_template_options()) const
233276 {
234277 json actual_messages;
235278
236- auto needs_adjustments = adjust_inputs && (false
237- || !caps_.supports_system_role
238- || !caps_.supports_tools
239- || !caps_.supports_tool_responses
240- || !caps_.supports_tool_calls
241- || caps_.requires_object_arguments
242- || caps_.requires_typed_content
279+ auto has_tools = inputs.tools .is_array () && !inputs.tools .empty ();
280+ auto has_tool_calls = false ;
281+ auto has_tool_responses = false ;
282+ auto has_string_content = false ;
283+ for (const auto & message : inputs.messages ) {
284+ if (!message[" tool_calls" ].is_null ()) {
285+ has_tool_calls = true ;
286+ }
287+ if (message[" role" ] == " tool" ) {
288+ has_tool_responses = true ;
289+ }
290+ if (message[" content" ].is_string ()) {
291+ has_string_content = true ;
292+ }
293+ }
294+
295+ auto polyfill_system_role = opts.polyfill_system_role && !caps_.supports_system_role ;
296+ auto polyfill_tools = opts.polyfill_tools && has_tools && !caps_.supports_tools ;
297+ auto polyfill_tool_call_example = polyfill_tools && opts.polyfill_tool_call_examples ;
298+ auto polyfill_tool_calls = opts.polyfill_tool_calls && has_tool_calls && !caps_.supports_tool_calls ;
299+ auto polyfill_tool_responses = opts.polyfill_tool_responses && has_tool_responses && !caps_.supports_tool_responses ;
300+ auto polyfill_object_arguments = opts.polyfill_object_arguments && has_tool_calls && caps_.requires_object_arguments ;
301+ auto polyfill_typed_content = opts.polyfill_typed_content && has_string_content && caps_.requires_typed_content ;
302+
303+ auto needs_polyfills = opts.apply_polyfills && (false
304+ || polyfill_system_role
305+ || polyfill_tools
306+ || polyfill_tool_calls
307+ || polyfill_tool_responses
308+ || polyfill_object_arguments
309+ || polyfill_typed_content
243310 );
244- if (needs_adjustments) {
311+
312+ if (needs_polyfills) {
245313 actual_messages = json::array ();
246314
247315 auto add_message = [&](const json & msg) {
248- if (caps_. requires_typed_content && msg.contains (" content" ) && !msg.at (" content" ).is_null () && msg.at (" content" ).is_string ()) {
316+ if (polyfill_typed_content && msg.contains (" content" ) && !msg.at (" content" ).is_null () && msg.at (" content" ).is_string ()) {
249317 actual_messages.push_back ({
250318 {" role" , msg.at (" role" )},
251319 {" content" , {{
@@ -268,16 +336,14 @@ class chat_template {
268336 pending_system.clear ();
269337 }
270338 };
271- auto needs_tools_in_system = !tools.is_null () && tools.size () > 0 && !caps_.supports_tools ;
272339
273340 json adjusted_messages;
274- if (needs_tools_in_system) {
275- adjusted_messages = add_system (messages,
276- " \n\n "
277- " You can call any of the following tools to satisfy the user's requests: " + tools.dump (2 ) + " \n\n "
278- " Example tool call syntax:\n\n " + tool_call_example_ + " \n\n " );
341+ if (polyfill_tools) {
342+ adjusted_messages = add_system (inputs.messages ,
343+ " You can call any of the following tools to satisfy the user's requests: " + minja::Value (inputs.tools ).dump (2 , /* to_json= */ true ) +
344+ (!polyfill_tool_call_example || tool_call_example_.empty () ? " " : " \n\n Example tool call syntax:\n\n " + tool_call_example_));
279345 } else {
280- adjusted_messages = messages;
346+ adjusted_messages = inputs. messages ;
281347 }
282348
283349 for (const auto & message_ : adjusted_messages) {
@@ -288,7 +354,7 @@ class chat_template {
288354 std::string role = message.at (" role" );
289355
290356 if (message.contains (" tool_calls" )) {
291- if (caps_. requires_object_arguments || !caps_. supports_tool_calls ) {
357+ if (polyfill_object_arguments || polyfill_tool_calls ) {
292358 for (auto & tool_call : message.at (" tool_calls" )) {
293359 if (tool_call[" type" ] == " function" ) {
294360 auto & function = tool_call.at (" function" );
@@ -303,7 +369,7 @@ class chat_template {
303369 }
304370 }
305371 }
306- if (!caps_. supports_tool_calls ) {
372+ if (polyfill_tool_calls ) {
307373 auto content = message.at (" content" );
308374 auto tool_calls = json::array ();
309375 for (const auto & tool_call : message.at (" tool_calls" )) {
@@ -330,7 +396,7 @@ class chat_template {
330396 message.erase (" tool_calls" );
331397 }
332398 }
333- if (!caps_. supports_tool_responses && role == " tool" ) {
399+ if (polyfill_tool_responses && role == " tool" ) {
334400 message[" role" ] = " user" ;
335401 auto obj = json {
336402 {" tool_response" , {
@@ -347,7 +413,7 @@ class chat_template {
347413 message.erase (" name" );
348414 }
349415
350- if (!message[" content" ].is_null () && !caps_. supports_system_role ) {
416+ if (!message[" content" ].is_null () && polyfill_system_role ) {
351417 std::string content = message.at (" content" );
352418 if (role == " system" ) {
353419 if (!pending_system.empty ()) pending_system += " \n " ;
@@ -366,28 +432,40 @@ class chat_template {
366432 }
367433 add_message (message);
368434 }
369- if (!caps_.supports_system_role ) {
370- flush_sys ();
371- }
435+ flush_sys ();
372436 } else {
373- actual_messages = messages;
437+ actual_messages = inputs. messages ;
374438 }
375439
376440 auto context = minja::Context::make (json ({
377441 {" messages" , actual_messages},
378- {" add_generation_prompt" , add_generation_prompt},
379- {" bos_token" , bos_token_},
380- {" eos_token" , eos_token_},
442+ {" add_generation_prompt" , inputs.add_generation_prompt },
381443 }));
382-
383- if (!tools.is_null ()) {
384- auto tools_val = minja::Value (tools);
385- context->set (" tools" , tools_val);
444+ if (opts.use_bos_token ) {
445+ context->set (" bos_token" , bos_token_);
446+ }
447+ if (opts.use_eos_token ) {
448+ context->set (" eos_token" , eos_token_);
449+ }
450+ if (opts.define_strftime_now ) {
451+ auto now = inputs.now ;
452+ context->set (" strftime_now" , Value::callable ([now](const std::shared_ptr<minja::Context> &, minja::ArgumentsValue & args) {
453+ args.expectArgs (" strftime_now" , {1 , 1 }, {0 , 0 });
454+ auto format = args.args [0 ].get <std::string>();
455+
456+ auto time = std::chrono::system_clock::to_time_t (now);
457+ auto local_time = *std::localtime (&time);
458+ std::ostringstream ss;
459+ ss << std::put_time (&local_time, format.c_str ());
460+ return ss.str ();
461+ }));
462+ }
463+ if (!inputs.tools .is_null ()) {
464+ context->set (" tools" , minja::Value (inputs.tools ));
386465 }
387- if (!extra_context.is_null ()) {
388- for (auto & kv : extra_context.items ()) {
389- minja::Value val (kv.value ());
390- context->set (kv.key (), val);
466+ if (!inputs.extra_context .is_null ()) {
467+ for (auto & kv : inputs.extra_context .items ()) {
468+ context->set (kv.key (), minja::Value (kv.value ()));
391469 }
392470 }
393471
@@ -404,7 +482,7 @@ class chat_template {
404482 std::string existing_system = messages_with_system.at (0 ).at (" content" );
405483 messages_with_system[0 ] = json {
406484 {" role" , " system" },
407- {" content" , existing_system + " \n " + system_prompt},
485+ {" content" , existing_system + " \n\n " + system_prompt},
408486 };
409487 } else {
410488 messages_with_system.insert (messages_with_system.begin (), json {
0 commit comments