@@ -17,19 +17,26 @@ using json = nlohmann::ordered_json;
1717
1818namespace minja {
1919
20+ struct chat_template_caps {
21+ bool supports_tools = false ;
22+ bool supports_tool_calls = false ;
23+ bool supports_tool_responses = false ;
24+ bool supports_system_role = false ;
25+ bool supports_parallel_tool_calls = false ;
26+ bool supports_tool_call_id = false ;
27+ // meta-llama/Llama-3.1-8B-Instruct expects arguments to be an object.
28+ // Most other templates (and OpenAI's API) expect the arguments object to be stringified.
29+ bool requires_object_arguments = false ;
30+ // CohereForAI/c4ai-command-r-plus simple variant
31+ bool requires_non_null_content = false ;
32+ // MiniMaxAI/MiniMax-Text-01 special
33+ bool requires_typed_content = false ;
34+ };
35+
2036class chat_template {
21- public:
2237
2338 private:
24- bool supports_tools_ = true ;
25- bool supports_tool_calls_ = true ;
26- // Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
27- // Most other templates (and OpenAI's API) expect the arguments object to be stringified.
28- bool requires_object_arguments_ = false ;
29- bool requires_typed_content_ = false ;
30- bool supports_system_role_ = true ;
31- bool supports_parallel_tool_calls_ = false ;
32- bool supports_code_interpreter_ = false ;
39+ chat_template_caps caps_;
3340 std::string source_;
3441 std::string bos_token_;
3542 std::string eos_token_;
@@ -43,15 +50,16 @@ class chat_template {
4350 {
4451 try {
4552 auto prompt = apply (messages, tools, add_generation_prompt, extra_context, /* adjust_inputs= */ false );
46- // fprintf(stderr, "Prompt : %s\n", prompt.c_str());
53+ // fprintf(stderr, "try_raw_render : %s\n", prompt.c_str());
4754 return prompt;
4855 } catch (const std::exception & e) {
49- // fprintf(stderr, "Error : %s\n", e.what());
56+ // fprintf(stderr, "try_raw_render error : %s\n", e.what());
5057 return " " ;
5158 }
5259 }
5360
5461 public:
62+
5563 chat_template (const std::string & source, const std::string & bos_token, const std::string & eos_token)
5664 : source_(source), bos_token_(bos_token), eos_token_(eos_token)
5765 {
@@ -60,82 +68,120 @@ class chat_template {
6068 /* .lstrip_blocks = */ true ,
6169 /* .keep_trailing_newline = */ false ,
6270 });
63- supports_tool_calls_ = source.find (" tool_calls" ) != std::string::npos;
64- supports_tools_ =
65- try_raw_render ({
66- {{" role" , " user" }, {" content" , " Hey" }},
67- }, {
68- {
69- {" type" , " function" },
70- {" function" , {
71- {" name" , " some_tool" },
72- {" parameters" , {{" type" , " string" }}},
73- }},
74- },
75- }, false ).find (" some_tool" ) != std::string::npos;
7671
77- requires_object_arguments_ =
78- try_raw_render ({
79- {
80- {" role" , " user" },
81- {" content" , " Hey" }
82- },
83- {
84- {" role" , " assistant" },
85- {" tool_calls" , json::array ({
86- {
87- {" id" , " call_1___" },
88- {" type" , " function" },
89- {" function" , {
90- {" arguments" , {
91- {" code" , " print('Hello, World!')" },
92- }},
93- {" name" , " ipython" },
94- }},
95- },
96- })},
97- }
98- }, {}, false ).find (" {\" code\" : \" print" ) != std::string::npos
99- && try_raw_render ({
100- {
101- {" role" , " user" },
102- {" content" , " Hey" }
103- },
104- {
105- {" role" , " assistant" },
106- {" tool_calls" , json::array ({
107- {
108- {" id" , " call_1___" },
109- {" type" , " function" },
110- {" function" , {
111- {" arguments" , " {\" code\" : \" print('Hello, World!')\" }" },
112- {" name" , " ipython" },
72+ auto contains = [](const std::string & haystack, const std::string & needle) {
73+ return haystack.find (needle) != std::string::npos;
74+ };
75+
76+ const std::string user_needle = " <User Needle>" ;
77+ const std::string sys_needle = " <System Needle>" ;
78+ const json dummy_str_user_msg = {{" role" , " user" }, {" content" , user_needle}};
79+ const json dummy_typed_user_msg = {{" role" , " user" }, {" content" , json::array ({{{" type" , " text" }, {" text" , user_needle}}})}};
80+
81+ caps_.requires_typed_content =
82+ !contains (try_raw_render (json::array ({dummy_str_user_msg}), {}, false ), user_needle)
83+ && contains (try_raw_render (json::array ({dummy_typed_user_msg}), {}, false ), user_needle);
84+
85+ const auto dummy_user_msg = caps_.requires_typed_content
86+ ? dummy_typed_user_msg
87+ : dummy_str_user_msg;
88+ const json needle_system_msg = {
89+ {" role" , " system" },
90+ {" content" , caps_.requires_typed_content ? json::array ({{{" type" , " text" }, {" text" , sys_needle}}}) : json (sys_needle)},
91+ };
92+
93+ caps_.supports_system_role = contains (try_raw_render ({needle_system_msg, dummy_user_msg,}, {}, false ), sys_needle);
94+
95+ auto out = try_raw_render (json::array ({
96+ dummy_user_msg
97+ }), json::array ({
98+ {
99+ {" name" , " some_tool" },
100+ {" type" , " function" },
101+ {" function" , {
102+ {" name" , " some_tool" },
103+ {" description" , " Some tool." },
104+ {" parameters" , {
105+ {" type" , " object" },
106+ {" properties" , {
107+ {" arg" , {
108+ {" type" , " string" },
109+ {" description" , " Some argument." },
113110 }},
114- },
115- })},
116- }
117- }, {}, false ).find (" {\" code\" : \" print" ) == std::string::npos;
111+ }},
112+ {" required" , json::array ({ " arg" })},
113+ }},
114+ }},
115+ },
116+ }), false );
117+ caps_.supports_tools = contains (out, " some_tool" );
118118
119- supports_parallel_tool_calls_ = source.find (" tool_call_id" ) != std::string::npos;
119+ auto make_tool_calls_msg = [&](const json & tool_calls) {
120+ return json {
121+ {" role" , " assistant" },
122+ {" content" , nullptr },
123+ {" tool_calls" , tool_calls},
124+ };
125+ };
126+ auto make_tool_call = [](const std::string & tool_name, const json & arguments) {
127+ return json {
128+ {" id" , " call_1___" },
129+ {" type" , " function" },
130+ {" function" , {
131+ {" arguments" , arguments},
132+ {" name" , tool_name},
133+ }},
134+ };
135+ };
136+ const json dummy_args_obj {{" argument_needle" , " print('Hello, World!')" }};
137+
138+ // Note: the arguments are rendered in both cases, but may be double-escaped, which we don't want.
139+ out = try_raw_render (json::array ({
140+ dummy_user_msg,
141+ make_tool_calls_msg (json::array ({make_tool_call (" ipython" , dummy_args_obj.dump ())})),
142+ }), {}, false );
143+ auto tool_call_renders_str_arguments = contains (out, " \" argument_needle\" :" ) || contains (out, " 'argument_needle':" );
144+ out = try_raw_render (json::array ({
145+ dummy_user_msg,
146+ make_tool_calls_msg (json::array ({make_tool_call (" ipython" , dummy_args_obj)})),
147+ }), {}, false );
148+ auto tool_call_renders_obj_arguments = contains (out, " \" argument_needle\" :" ) || contains (out, " 'argument_needle':" );
120149
121- supports_system_role_ = try_raw_render ({
122- {{" role" , " system" }, {" content" , " <System Needle>" }},
123- {{" role" , " user" }, {" content" , " Hey" }}
124- }, {}, false ).find (" <System Needle>" ) != std::string::npos;
150+ caps_.supports_tool_calls = tool_call_renders_str_arguments || tool_call_renders_obj_arguments;
151+ caps_.requires_object_arguments = !tool_call_renders_str_arguments && tool_call_renders_obj_arguments;
152+ auto out_empty = try_raw_render (json::array ({dummy_user_msg, {{" role" , " assistant" }, {" content" , " " }}}), {}, false );
153+ auto out_null = try_raw_render (json::array ({dummy_user_msg, {{" role" , " assistant" }, {" content" , nullptr }}}), {}, false );
154+ caps_.requires_non_null_content = contains (out_empty, user_needle) && !contains (out_null, user_needle);
125155
126- requires_typed_content_ = try_raw_render ({{{" role" , " user" }, {" content" , " Hey" }}}, {}, false ).find (" Hey" ) == std::string::npos
127- && try_raw_render ({{{" role" , " user" }, {" content" , {{{" type" , " text" }, {" text" , " Hey" }}}}}}, {}, false ).find (" Hey" ) != std::string::npos;
156+ if (caps_.supports_tool_calls ) {
157+ auto dummy_args = caps_.requires_object_arguments ? dummy_args_obj : json (dummy_args_obj.dump ());
158+ auto tc1 = make_tool_call (" test_tool1" , dummy_args);
159+ auto tc2 = make_tool_call (" test_tool2" , dummy_args);
160+ auto out = try_raw_render (json::array ({
161+ dummy_user_msg,
162+ make_tool_calls_msg (json::array ({tc1, tc2})),
163+ }), {}, false );
164+ caps_.supports_parallel_tool_calls = contains (out, " test_tool1" ) && contains (out, " test_tool2" );
128165
129- supports_code_interpreter_ = source.find (" code_interpreter" ) != std::string::npos;
166+ out = try_raw_render (json::array ({
167+ dummy_user_msg,
168+ make_tool_calls_msg (json::array ({tc1})),
169+ {
170+ {" role" , " tool" },
171+ {" name" , " test_tool1" },
172+ {" content" , " Some response!" },
173+ {" tool_call_id" , " call_911_" },
174+ }
175+ }), {}, false );
176+ caps_.supports_tool_responses = contains (out, " Some response!" );
177+ caps_.supports_tool_call_id = contains (out, " call_911_" );
178+ }
130179 }
131180
132181 const std::string & source () const { return source_; }
133182 const std::string & bos_token () const { return bos_token_; }
134183 const std::string & eos_token () const { return eos_token_; }
135- bool supports_tools () const { return supports_tools_; }
136- bool supports_tool_calls () const { return supports_tool_calls_; }
137- bool supports_parallel_tool_calls () const { return supports_parallel_tool_calls_; }
138- bool requires_object_arguments () const { return requires_object_arguments_; }
184+ const chat_template_caps & original_caps () const { return caps_; }
139185
140186 std::string apply (
141187 const nlohmann::ordered_json & messages,
@@ -145,33 +191,20 @@ class chat_template {
145191 bool adjust_inputs = true) const
146192 {
147193 json actual_messages;
148- json actual_tools;
149-
150- auto has_code_interpreter = false ;
151- for (const auto & tool : tools) {
152- if (tool.contains (" type" ) && tool.at (" type" ) == " code_interpreter" ) {
153- has_code_interpreter = true ;
154- break ;
155- }
156- }
157-
158- if (adjust_inputs && !tools.is_null () && !supports_code_interpreter_ && has_code_interpreter) {
159- actual_tools = json::array ();
160- for (const auto & tool : tools) {
161- if (tool.contains (" type" ) && tool.at (" type" ) == " code_interpreter" && !supports_code_interpreter_) {
162- continue ;
163- }
164- actual_tools.push_back (tool);
165- }
166- } else if (!tools.is_null ()) {
167- actual_tools = tools;
168- }
169194
170- if (adjust_inputs && (requires_object_arguments_ || !supports_system_role_ || !supports_tools_ || !supports_tool_calls_ || requires_typed_content_)) {
195+ auto needs_adjustments = adjust_inputs && (false
196+ || !caps_.supports_system_role
197+ || !caps_.supports_tools
198+ || !caps_.supports_tool_responses
199+ || !caps_.supports_tool_calls
200+ || caps_.requires_object_arguments
201+ || caps_.requires_typed_content
202+ );
203+ if (needs_adjustments) {
171204 actual_messages = json::array ();
172205
173206 auto add_message = [&](const json & msg) {
174- if (requires_typed_content_ && msg.contains (" content" ) && !msg.at (" content" ).is_null () && msg.at (" content" ).is_string ()) {
207+ if (caps_. requires_typed_content && msg.contains (" content" ) && !msg.at (" content" ).is_null () && msg.at (" content" ).is_string ()) {
175208 actual_messages.push_back ({
176209 {" role" , msg.at (" role" )},
177210 {" content" , {{
@@ -194,7 +227,7 @@ class chat_template {
194227 pending_system.clear ();
195228 }
196229 };
197- auto needs_tools_in_system = !tools.is_null () && tools.size () > 0 && !supports_tools_ ;
230+ auto needs_tools_in_system = !tools.is_null () && tools.size () > 0 && !caps_. supports_tools ;
198231
199232 for (const auto & message_ : needs_tools_in_system ? add_system (messages, " Available tools: " + tools.dump (2 )) : messages) {
200233 auto message = message_;
@@ -204,7 +237,7 @@ class chat_template {
204237 std::string role = message.at (" role" );
205238
206239 if (message.contains (" tool_calls" )) {
207- if (requires_object_arguments_ || !supports_tool_calls_ ) {
240+ if (caps_. requires_object_arguments || !caps_. supports_tool_calls ) {
208241 for (auto & tool_call : message.at (" tool_calls" )) {
209242 if (tool_call[" type" ] == " function" ) {
210243 auto & function = tool_call.at (" function" );
@@ -219,7 +252,7 @@ class chat_template {
219252 }
220253 }
221254 }
222- if (!supports_tool_calls_ ) {
255+ if (!caps_. supports_tool_calls ) {
223256 auto content = message.at (" content" );
224257 auto tool_calls = json::array ();
225258 for (const auto & tool_call : message.at (" tool_calls" )) {
@@ -246,7 +279,7 @@ class chat_template {
246279 message.erase (" tool_calls" );
247280 }
248281 }
249- if (!supports_tools_ && role == " tool" ) {
282+ if (!caps_. supports_tool_responses && role == " tool" ) {
250283 message[" role" ] = " user" ;
251284 auto obj = json {
252285 {" tool_response" , {
@@ -261,7 +294,7 @@ class chat_template {
261294 message.erase (" name" );
262295 }
263296
264- if (!message[" content" ].is_null () && !supports_system_role_ ) {
297+ if (!message[" content" ].is_null () && !caps_. supports_system_role ) {
265298 std::string content = message.at (" content" );
266299 if (role == " system" ) {
267300 if (!pending_system.empty ()) pending_system += " \n " ;
@@ -280,7 +313,7 @@ class chat_template {
280313 }
281314 add_message (message);
282315 }
283- if (!supports_system_role_ ) {
316+ if (!caps_. supports_system_role ) {
284317 flush_sys ();
285318 }
286319 } else {
@@ -295,7 +328,7 @@ class chat_template {
295328 }));
296329
297330 if (!tools.is_null ()) {
298- auto tools_val = minja::Value (actual_tools );
331+ auto tools_val = minja::Value (tools );
299332 context->set (" tools" , tools_val);
300333 }
301334 if (!extra_context.is_null ()) {
@@ -305,7 +338,10 @@ class chat_template {
305338 }
306339 }
307340
308- return template_root_->render (context);
341+ auto ret = template_root_->render (context);
342+ // fprintf(stderr, "actual_messages: %s\n", actual_messages.dump(2).c_str());
343+ // fprintf(stderr, "apply: %s\n\n", ret.c_str());
344+ return ret;
309345 }
310346
311347 static nlohmann::ordered_json add_system (const nlohmann::ordered_json & messages, const std::string & system_prompt) {
0 commit comments