@@ -16,6 +16,7 @@
 #include <string>
 #include <vector>
 
+using json = nlohmann::ordered_json;
 
 static std::string format_time(const std::chrono::system_clock::time_point & now, const std::string & format) {
     auto time = std::chrono::system_clock::to_time_t(now);
@@ -721,16 +722,23 @@ static void foreach_function(const json & tools, const std::function<void(const
 
 static std::string apply(
     const common_chat_template & tmpl,
-    const nlohmann::ordered_json & messages,
-    const nlohmann::ordered_json & tools,
-    bool add_generation_prompt,
-    const nlohmann::ordered_json & extra_context = nlohmann::ordered_json())
+    const struct templates_params & inputs,
+    const std::optional<json> & messages_override = std::nullopt,
+    const std::optional<json> & tools_override = std::nullopt,
+    const std::optional<json> & additional_context = std::nullopt)
 {
     minja::chat_template_inputs tmpl_inputs;
-    tmpl_inputs.messages = messages;
-    tmpl_inputs.tools = tools;
-    tmpl_inputs.add_generation_prompt = add_generation_prompt;
-    tmpl_inputs.extra_context = extra_context;
+    tmpl_inputs.messages = messages_override ? *messages_override : inputs.messages;
+    if (tools_override) {
+        tmpl_inputs.tools = *tools_override;
+    } else {
+        tmpl_inputs.tools = inputs.tools.empty() ? json() : inputs.tools;
+    }
+    tmpl_inputs.add_generation_prompt = inputs.add_generation_prompt;
+    tmpl_inputs.extra_context = inputs.extra_context;
+    if (additional_context) {
+        tmpl_inputs.extra_context.merge_patch(*additional_context);
+    }
     // TODO: add flag to control date/time, if only for testing purposes.
     // tmpl_inputs.now = std::chrono::system_clock::now();
 
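A sketch of the call patterns the new signature supports, for reference when reading the call-site hunks below (names like `tweaked_messages` are illustrative; `inputs` is assumed to be a populated `templates_params`):

    // Take everything (messages, tools, add_generation_prompt, extra_context) from inputs:
    auto p1 = apply(tmpl, inputs);
    // Substitute the message list, keep the rest of inputs:
    auto p2 = apply(tmpl, inputs, /* messages_override= */ tweaked_messages);
    // Keep messages and tools, layer extra keys onto inputs.extra_context via merge_patch:
    auto p3 = apply(tmpl, inputs, std::nullopt, std::nullopt, json {{"enable_thinking", true}});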
@@ -829,7 +837,7 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
         inputs.messages,
         "Respond in JSON format, either with `tool_call` (a request to call tools) or with `response` reply to the user's request");
 
-    data.prompt = apply(tmpl, tweaked_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, inputs.extra_context);
+    data.prompt = apply(tmpl, inputs, /* messages_override= */ tweaked_messages);
     data.format = COMMON_CHAT_FORMAT_GENERIC;
     return data;
 }
@@ -901,7 +909,7 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
     data.preserved_tokens = {
         "[TOOL_CALLS]",
     };
-    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    data.prompt = apply(tmpl, inputs);
     data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO;
     return data;
 }
@@ -926,7 +934,7 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
             adjusted_messages.push_back(msg);
         }
     }
-    data.prompt = apply(tmpl, adjusted_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {});
+    data.prompt = apply(tmpl, inputs, /* messages_override= */ adjusted_messages);
     data.format = COMMON_CHAT_FORMAT_COMMAND_R7B;
     if (string_ends_with(data.prompt, "<|START_THINKING|>")) {
         if (!inputs.enable_thinking) {
@@ -1119,7 +1127,7 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te
     } else {
         data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
     }
-    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
+    data.prompt = apply(tmpl, inputs, /* messages_override= */ std::nullopt, /* tools_override= */ std::nullopt, json {
         {"date_string", format_time(inputs.now, "%d %b %Y")},
         {"tools_in_user_message", false},
         {"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
@@ -1181,7 +1189,7 @@ static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool w
 
 static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
-    auto prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    auto prompt = apply(tmpl, inputs);
 
     // Hacks to fix the official (broken) prompt.
     // It is advisable to use --chat-template-file models/templates/llama-cpp-deepseek-r1.jinja instead,
@@ -1272,7 +1280,7 @@ static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
 static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
     LOG_DBG("%s\n", __func__);
     common_chat_params data;
-    data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
+    data.prompt = apply(tmpl, inputs, /* messages_override= */ std::nullopt, /* tools_override= */ json(), json {
         {"datetime", format_time(inputs.now, "%b %d %Y %H:%M:%S GMT")},
         {"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
     });
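Note the distinction this call site relies on: an engaged `tools_override` holding a default-constructed `json()` (i.e. JSON null) takes the `if (tools_override)` branch in apply() and blanks out the template's tools, matching the old `/* tools= */ nullptr` behavior, whereas `std::nullopt` falls back to `inputs.tools`. Illustrative spellings of the two cases:

    apply(tmpl, inputs, std::nullopt, /* tools_override= */ json());       // tools rendered as JSON null
    apply(tmpl, inputs, std::nullopt, /* tools_override= */ std::nullopt); // tools taken from inputs.tools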
@@ -1324,7 +1332,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
     // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
     // If the function is python, we also allow raw python code (if the line after `python\n` doesn't start w/ opening `{`), which the model seems to prefer for multiline code.
     common_chat_params data;
-    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    data.prompt = apply(tmpl, inputs);
     data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2;
     if (inputs.tools.is_array() && !inputs.tools.empty()) {
         data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
@@ -1451,7 +1459,7 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
         data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
     }
 
-    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    data.prompt = apply(tmpl, inputs);
     // TODO: if (has_raw_python)
     return data;
 }
@@ -1481,11 +1489,9 @@ static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser
 static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
 
-    json additional_context = {
+    data.prompt = apply(tmpl, inputs, /* messages_override= */ std::nullopt, /* tools_override= */ std::nullopt, json {
         {"enable_thinking", inputs.enable_thinking},
-    };
-
-    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, additional_context);
+    });
     data.format = COMMON_CHAT_FORMAT_HERMES_2_PRO;
     if (string_ends_with(data.prompt, "<think>\n")) {
         if (!inputs.enable_thinking) {
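Because `additional_context` is folded in with `merge_patch` (RFC 7386 semantics in nlohmann::json), its keys overwrite same-named keys already present in `inputs.extra_context`, and a null value would remove one. A minimal sketch with made-up keys:

    json ctx = {{"enable_thinking", false}, {"other", 1}}; // stand-in for inputs.extra_context
    ctx.merge_patch(json {{"enable_thinking", true}});     // the per-call override wins
    // ctx is now {"enable_thinking": true, "other": 1}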
@@ -1672,7 +1678,7 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
 
 static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
-    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, inputs.extra_context);
+    data.prompt = apply(tmpl, inputs);
     data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
     data.grammar_lazy = false;
     if (!inputs.json_schema.is_null()) {