 
 const common_grammar_options grammar_options {
     /* .dotall = */ false,
-    // /* .compact_spaces = */ false,
-    /* .compact_spaces = */ true,
+    /* .compact_spaces = */ false,
+    // /* .compact_spaces = */ true,
 };
 
 static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) {
@@ -59,13 +59,11 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons
  * Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between.
  * Aggregates the prefix, suffix and in-between text into the content.
  */
-static common_chat_msg parse_json_tool_calls(const json & tools, const std::string& input, const std::regex & function_regex, const std::regex & close_regex, bool check_names, bool allow_raw_python = false) {
+static common_chat_msg parse_json_tool_calls(const json & tools, const std::string& input, const std::optional<std::regex> & trigger_opt, const std::regex & function_regex, const std::regex & close_regex, bool check_names, bool allow_raw_python = false) {
     std::smatch match;
 
     common_chat_msg result;
     result.role = "assistant";
-    auto end = input.end();
-    auto it = input.begin();
 
     std::vector<std::string> tool_names;
     if (check_names) {
@@ -77,6 +75,18 @@ static common_chat_msg parse_json_tool_calls(const json & tools, const std::stri
         }
     }
 
+    auto end = input.end();
+    auto it = input.begin();
+
+    if (trigger_opt) {
+        if (!std::regex_search(it, end, match, *trigger_opt)) {
+            result.content = input;
+            return result;
+        }
+        result.content = match.prefix().str();
+        it = match.suffix().first;
+    }
+
     while (it != end) {
         std::sregex_iterator rend;
         std::sregex_iterator rit(it, end, function_regex);
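The new trigger_opt parameter makes the parser treat everything before an optional trigger pattern as plain assistant content and only begin matching tool calls after it; if the trigger never appears, the whole input is returned as content. It is used below by the new DeepSeek R1 handler. A minimal self-contained sketch of the prefix/suffix mechanics (the strings here are hypothetical, not part of the patch):

    #include <iostream>
    #include <optional>
    #include <regex>
    #include <string>

    int main() {
        const std::string input = "Thinking out loud... TRIGGER {\"name\": \"f\"}";
        const std::optional<std::regex> trigger_opt{std::regex("TRIGGER ")};

        std::smatch match;
        auto it  = input.cbegin();
        auto end = input.cend();
        if (trigger_opt && std::regex_search(it, end, match, *trigger_opt)) {
            // Everything before the trigger becomes plain content...
            std::cout << "content: " << match.prefix().str() << "\n";
            // ...and tool-call parsing resumes right after the match.
            it = match.suffix().first;
            std::cout << "rest: " << std::string(it, end) << "\n";
        }
    }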
@@ -142,24 +152,6 @@ static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& in
     return result;
 }
 
-static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages, const std::string & system_prompt) {
-    json messages_with_system = messages;
-
-    if (messages_with_system.size() > 0 && messages_with_system[0].at("role") == "system") {
-        std::string existing_system = messages_with_system.at(0).at("content");
-        messages_with_system[0] = json {
-            {"role", "system"},
-            {"content", existing_system + "\n" + system_prompt},
-        };
-    } else {
-        messages_with_system.insert(messages_with_system.begin(), json {
-            {"role", "system"},
-            {"content", system_prompt},
-        });
-    }
-    return messages_with_system;
-}
-
 class text_chat_parser : public common_chat_parser {
 public:
     std::optional<common_chat_msg> parse_partial(const std::string & input) override {
@@ -291,12 +283,11 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
         builder.add_schema("root", schema);
     }, grammar_options);
 
-    // TODO: add schema to system prompt.
-    auto tweaked_messages = add_system(
+    auto tweaked_messages = common_chat_template::add_system(
         params.messages,
         "Respond in JSON format, either with a request to call tools or with a response to the user's request. Here is the schema for all responses:\n\n```json\n" + schema.dump(2) + "\n```");
 
-    data.prompt = tmpl.apply(tweaked_messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
+    data.prompt = tmpl.apply(tweaked_messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
     data.parser = std::make_unique<monolithic_chat_parser>([&](const std::string & input) {
         json data = json::parse(input);
         common_chat_msg result;
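Judging from the local helper removed earlier (whose merge-or-prepend logic presumably moved into common_chat_template::add_system), the schema instructions are appended to an existing leading system message rather than inserted as a second one. An illustrative before/after, with messages invented for the example:

    // before
    [{"role": "system", "content": "You are helpful."}, {"role": "user", "content": "Hi"}]

    // after add_system(messages, "Respond in JSON format, ...")
    [{"role": "system", "content": "You are helpful.\nRespond in JSON format, ..."}, {"role": "user", "content": "Hi"}]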
@@ -363,7 +354,7 @@ static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_cha
     if (params.tool_choice != "required") {
         data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true});
     }
-    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
+    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
     data.parser = std::make_unique<monolithic_chat_parser>([](const std::string & input) -> common_chat_msg {
         return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
     });
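For reference, parse_prefixed_json_tool_call_array consumes output of roughly this shape (the values are illustrative, and the 9-character id follows Mistral Nemo's tool-call id convention, an assumption not visible in this hunk):

    [TOOL_CALLS][{"name": "get_weather", "arguments": {"location": "Paris"}, "id": "123456789"}]

Since "[TOOL_CALLS]" is registered as a grammar trigger with .at_start = true, the grammar only kicks in once the model emits that prefix, except when tool_choice is "required".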
@@ -396,14 +387,13 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
         builder.add_rule("root", string_join(tool_rules, " | "));
     }, grammar_options);
     data.additional_stops.push_back("<|eom_id|>");
-    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true, {
+    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt, {
         {"builtin_tools", builtin_tools},
     });
     data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
         static std::regex function_regex("<\\|python_tag\\|>\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": ");
         static std::regex close_regex("\\}");
-        auto res = parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true);
-        return res;
+        return parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true);
     });
     fprintf(stderr, "Grammar: %s\n", data.grammar.c_str());
     return data;
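The function_regex above implies Llama 3.1 tool calls of this form (the call itself is illustrative):

    <|python_tag|>{"name": "get_weather", "parameters": {"location": "Paris"}}

parse_json_tool_calls captures the function name from the regex group and parses the JSON that follows, up to the closing "}", while <|eom_id|> is registered as an extra stop so generation halts after the call.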
@@ -438,17 +428,31 @@ static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_
         builder.add_rule("root", string_join(tool_rules, " | "));
     }, grammar_options);
     data.additional_stops.push_back("<|eom_id|>");
-    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true, {});
+    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt, {});
     data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
         static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": ");
         static std::regex close_regex("\\}");
-        auto res = parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true);
+        auto res = parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true);
         return res;
     });
     fprintf(stderr, "Grammar: %s\n", data.grammar.c_str());
     return data;
 }
 
+static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
+    fprintf(stderr, "[%s]\n", __func__);
+    common_chat_data data;
+    data.grammar = "root ::= .*";
+    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
+    data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
+        static std::regex trigger_regex("<｜tool▁calls▁begin｜>");
+        static std::regex function_regex("<｜tool▁call▁begin｜>function<｜tool▁sep｜>([^<]+)\n```json\n");
+        static std::regex close_regex("```<｜tool▁call▁end｜>");
+        return parse_json_tool_calls(params.tools, input, trigger_regex, function_regex, close_regex, /* check_names= */ true);
+    });
+    return data;
+}
+
 static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
     fprintf(stderr, "[%s]\n", __func__);
     common_chat_data data;
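The three regexes in the new DeepSeek R1 handler above imply output of roughly this shape (tool name and arguments are illustrative, and the closing <｜tool▁calls▁end｜> token is assumed from the model's template rather than referenced in this hunk):

    I should check the weather.<｜tool▁calls▁begin｜><｜tool▁call▁begin｜>function<｜tool▁sep｜>get_weather
    ```json
    {"location": "Paris"}
    ```<｜tool▁call▁end｜><｜tool▁calls▁end｜>

The text before the trigger ("I should check the weather.") lands in result.content through the new trigger_opt path, which is what makes this work for a model that reasons aloud before calling tools. Note also the permissive placeholder grammar (root ::= .*): unlike the other handlers, this one does not constrain generation yet.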
@@ -481,7 +485,7 @@ static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_
     if (params.tool_choice != "required") {
         data.grammar_triggers.push_back({"functools[", /* .at_start = */ false});
     }
-    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
+    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
     data.parser = std::make_unique<monolithic_chat_parser>([](const std::string & input) -> common_chat_msg {
         return parse_prefixed_json_tool_call_array(input, "functools[", /* rstrip_prefix= */ 1);
     });
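FireFunction v2 uses the same prefixed-JSON-array scheme as Mistral Nemo, just with a different prefix; /* rstrip_prefix= */ 1 presumably rewinds one character so the "[" matched as part of "functools[" is kept as the start of the JSON array. Illustrative output:

    functools[{"name": "get_weather", "arguments": {"location": "Paris"}}]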
@@ -519,12 +523,12 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_tool_call(const
 
     }, grammar_options);
 
-    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
+    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
     data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
         static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
         static std::regex close_regex(R"($|(?=>>>))");
 
-        auto res = parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true, /* allow_raw_python= */ true);
+        auto res = parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ true, /* allow_raw_python= */ true);
         if (res.content.find("all\n") == 0) {
             res.content = res.content.substr(4);
         }
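The (?:>>>)?(\w+)\n function regex and the $|(?=>>>) close regex describe Functionary's >>>-delimited call syntax, e.g. (illustrative):

    >>>get_weather
    {"location": "Paris"}
    >>>python
    print("hello")

With allow_raw_python the python call may carry raw code instead of JSON arguments, and a leading "all\n" (Functionary's marker for a plain-text reply) is stripped from the content by the substr(4) above.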
@@ -587,7 +591,7 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(cons
         }
     }, grammar_options);
 
-    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
+    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
     data.parser = std::make_unique<monolithic_chat_parser>([params, has_raw_python, python_code_argument_name](const std::string & input) -> common_chat_msg {
         // This version of Functionary still supports the llama 3.1 tool call format for the python tool.
         static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
@@ -608,7 +612,7 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(cons
         }
         static std::regex function_regex(R"(<function=(\w+)>)");
         static std::regex close_regex(R"(</function>)");
-        return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ false, has_raw_python);
+        return parse_json_tool_calls(params.tools, input, std::nullopt, function_regex, close_regex, /* check_names= */ false, has_raw_python);
     });
     return data;
 }
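So this handler accepts two call shapes (values illustrative): raw python through the llama 3.1 style tag, or named calls in <function=...> wrappers:

    <|python_tag|>print("hello")
    <function=get_weather>{"location": "Paris"}</function>

Note that check_names is false here, so function names are not validated against params.tools.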
@@ -640,7 +644,7 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
         }
     }, grammar_options);
 
-    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
+    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
     data.parser = std::make_unique<monolithic_chat_parser>([&](const std::string & input) -> common_chat_msg {
         try {
             std::regex start_pattern(R"([\n\s]*<tool_call>)");
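The start_pattern anchors on Hermes 2 Pro's <tool_call> wrapper, i.e. output like the following (illustrative call):

    <tool_call>
    {"name": "get_weather", "arguments": {"location": "Paris"}}
    </tool_call>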
@@ -691,7 +695,7 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
 static common_chat_data common_chat_init_without_tools(const common_chat_template & tmpl, const struct common_chat_params & params) {
     fprintf(stderr, "[%s]\n", __func__);
     common_chat_data data;
-    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
+    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, params.add_generation_prompt);
     data.parser = std::make_unique<text_chat_parser>();
     if (!params.json_schema.is_null()) {
         if (!params.grammar.empty()) {
@@ -733,6 +737,9 @@ common_chat_data common_chat_init(const common_chat_template & tmpl, const struc
             return common_chat_init_llama_3_2_tool_calls(tmpl, params);
         }
     }
+    if (src.find("<｜tool▁calls▁begin｜>") != std::string::npos) {
+        return common_chat_init_deepseek_r1_tool_call(tmpl, params);
+    }
     // if (src.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) {
     //     TODO: Command-R-Plus
     // }