17 | 17 | #include <string> |
18 | 18 | #include <vector> |
19 | 19 |
| 20 | +using json = nlohmann::ordered_json; |
| 21 | + |
20 | 22 | static std::string format_time(const std::chrono::system_clock::time_point & now, const std::string & format) { |
21 | 23 | auto time = std::chrono::system_clock::to_time_t(now); |
22 | 24 | auto local_time = *std::localtime(&time); |
@@ -721,16 +723,23 @@ static void foreach_function(const json & tools, const std::function<void(const |
721 | 723 |
722 | 724 | static std::string apply( |
723 | 725 | const common_chat_template & tmpl, |
724 | | - const nlohmann::ordered_json & messages, |
725 | | - const nlohmann::ordered_json & tools, |
726 | | - bool add_generation_prompt, |
727 | | - const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) |
| 726 | + const struct templates_params & inputs, |
| 727 | + const std::optional<json> & messages_override = std::nullopt, |
| 728 | + const std::optional<json> & tools_override = std::nullopt, |
| 729 | + const std::optional<json> & additional_context = std::nullopt) |
728 | 730 | { |
729 | 731 | minja::chat_template_inputs tmpl_inputs; |
730 | | - tmpl_inputs.messages = messages; |
731 | | - tmpl_inputs.tools = tools; |
732 | | - tmpl_inputs.add_generation_prompt = add_generation_prompt; |
733 | | - tmpl_inputs.extra_context = extra_context; |
| 732 | + tmpl_inputs.messages = messages_override ? *messages_override : inputs.messages; |
| 733 | + if (tools_override) { |
| 734 | + tmpl_inputs.tools = *tools_override; |
| 735 | + } else { |
| 736 | + tmpl_inputs.tools = inputs.tools.empty() ? json() : inputs.tools; |
| 737 | + } |
| 738 | + tmpl_inputs.add_generation_prompt = inputs.add_generation_prompt; |
| 739 | + tmpl_inputs.extra_context = inputs.extra_context; |
| 740 | + if (additional_context) { |
| 741 | + tmpl_inputs.extra_context.merge_patch(*additional_context); |
| 742 | + } |
734 | 743 | // TODO: add flag to control date/time, if only for testing purposes. |
735 | 744 | // tmpl_inputs.now = std::chrono::system_clock::now(); |
736 | 745 |
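This hunk is the heart of the refactor: `apply()` now receives the whole `templates_params` struct plus three `std::optional<json>` overrides instead of four positional arguments. The resolution rule is uniform: an engaged optional wins (even when it holds `json()`, i.e. JSON null), while a disengaged one falls back to the corresponding field of `inputs`. A minimal standalone sketch of that rule, using a hypothetical trimmed-down stand-in for the real struct (`params_sketch` is not the actual type):

    #include <optional>

    #include <nlohmann/json.hpp>

    using json = nlohmann::ordered_json;

    // Hypothetical stand-in for templates_params; only the field the
    // sketch needs is shown.
    struct params_sketch {
        json tools;
    };

    // Mirrors the new apply(): an engaged optional always wins -- even when
    // it holds json() (null), which is how a caller forces "no tools".
    static json resolve_tools(const params_sketch & inputs,
                              const std::optional<json> & tools_override) {
        if (tools_override) {
            return *tools_override;
        }
        // Default path keeps the old per-call-site normalization: empty -> null.
        return inputs.tools.empty() ? json() : inputs.tools;
    }

This is also why the signature uses `std::optional<json>` rather than a defaulted `json` parameter: a plain `json` default could not distinguish "caller passed null on purpose" from "caller passed nothing".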
@@ -829,7 +838,7 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp |
829 | 838 | inputs.messages, |
830 | 839 | "Respond in JSON format, either with `tool_call` (a request to call tools) or with `response` reply to the user's request"); |
831 | 840 |
832 | | - data.prompt = apply(tmpl, tweaked_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, inputs.extra_context); |
| 841 | + data.prompt = apply(tmpl, inputs, /* messages_override= */ tweaked_messages); |
833 | 842 | data.format = COMMON_CHAT_FORMAT_GENERIC; |
834 | 843 | return data; |
835 | 844 | } |
@@ -905,7 +914,7 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat |
905 | 914 | data.preserved_tokens = { |
906 | 915 | "[TOOL_CALLS]", |
907 | 916 | }; |
908 | | - data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); |
| 917 | + data.prompt = apply(tmpl, inputs); |
909 | 918 | data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO; |
910 | 919 | return data; |
911 | 920 | } |
@@ -935,7 +944,7 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_ |
935 | 944 | adjusted_messages.push_back(msg); |
936 | 945 | } |
937 | 946 | } |
938 | | - data.prompt = apply(tmpl, adjusted_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {}); |
| 947 | + data.prompt = apply(tmpl, inputs, /* messages_override= */ adjusted_messages); |
939 | 948 | data.format = COMMON_CHAT_FORMAT_COMMAND_R7B; |
940 | 949 | if (string_ends_with(data.prompt, "<|START_THINKING|>")) { |
941 | 950 | if (!inputs.enable_thinking) { |
@@ -1123,7 +1132,7 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te |
1123 | 1132 | } else { |
1124 | 1133 | data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY; |
1125 | 1134 | } |
1126 | | - data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, { |
| 1135 | + data.prompt = apply(tmpl, inputs, /* messages_override= */ std::nullopt, /* tools_override= */ std::nullopt, json { |
1127 | 1136 | {"date_string", format_time(inputs.now, "%d %b %Y")}, |
1128 | 1137 | {"tools_in_user_message", false}, |
1129 | 1138 | {"builtin_tools", builtin_tools.empty() ? json() : builtin_tools}, |
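Note how the per-format variables above (`date_string`, `builtin_tools`, ...) are layered on top of `inputs.extra_context` via `merge_patch` inside `apply()` rather than replacing it. `nlohmann::json`'s `merge_patch` applies an RFC 7386 JSON Merge Patch: keys in the patch add to or overwrite keys in the target, and a null patch value deletes the key. A standalone illustration (values made up):

    #include <iostream>

    #include <nlohmann/json.hpp>

    using json = nlohmann::ordered_json;

    int main() {
        json extra_context      = {{"enable_thinking", true}, {"scratch", "x"}};
        json additional_context = {{"date_string", "01 Jan 2025"}, {"scratch", nullptr}};

        // Adds date_string, deletes scratch (null patch value), keeps the rest.
        extra_context.merge_patch(additional_context);

        std::cout << extra_context.dump() << "\n";
        // {"enable_thinking":true,"date_string":"01 Jan 2025"}
    }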
@@ -1188,7 +1197,7 @@ static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool w |
1188 | 1197 |
1189 | 1198 | static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) { |
1190 | 1199 | common_chat_params data; |
1191 | | - auto prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); |
| 1200 | + auto prompt = apply(tmpl, inputs); |
1192 | 1201 |
1193 | 1202 | // Hacks to fix the official (broken) prompt. |
1194 | 1203 | // It is advisable to use --chat-template-file models/templates/llama-cpp-deepseek-r1.jinja instead, |
@@ -1283,7 +1292,7 @@ static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) { |
1283 | 1292 | static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) { |
1284 | 1293 | LOG_DBG("%s\n", __func__); |
1285 | 1294 | common_chat_params data; |
1286 | | - data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, { |
| 1295 | + data.prompt = apply(tmpl, inputs, /* messages_override= */ std::nullopt, /* tools_override= */ json(), json { |
1287 | 1296 | {"datetime", format_time(inputs.now, "%b %d %Y %H:%M:%S GMT")}, |
1288 | 1297 | {"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))}, |
1289 | 1298 | }); |
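Worth calling out: this is the only call site that engages `tools_override`, and it does so with an explicit `json()` (null). FireFunction v2's template ignores the standard `tools` variable and instead consumes the pre-serialized `functions` string built on the line below it, so the old `/* tools= */ nullptr` argument must become a forced override rather than a fallback to `inputs.tools`:

    // std::nullopt -> apply() falls back to inputs.tools (normalized empty -> null)
    // json()       -> apply() passes JSON null to the template, regardless of inputs.tools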
@@ -1339,7 +1348,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_ |
1339 | 1348 | // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar |
1340 | 1349 | // If the function is python, we also allow raw python code (if the line after `python\n` doesn't start w/ opening `{`), which the model seems to prefer for multiline code. |
1341 | 1350 | common_chat_params data; |
1342 | | - data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); |
| 1351 | + data.prompt = apply(tmpl, inputs); |
1343 | 1352 | data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2; |
1344 | 1353 | if (inputs.tools.is_array() && !inputs.tools.empty()) { |
1345 | 1354 | data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; |
@@ -1466,7 +1475,7 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con |
1466 | 1475 | data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY; |
1467 | 1476 | } |
1468 | 1477 |
1469 | | - data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); |
| 1478 | + data.prompt = apply(tmpl, inputs); |
1470 | 1479 | // TODO: if (has_raw_python) |
1471 | 1480 | return data; |
1472 | 1481 | } |
@@ -1499,11 +1508,9 @@ static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser |
1499 | 1508 | static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) { |
1500 | 1509 | common_chat_params data; |
1501 | 1510 |
1502 | | - json additional_context = { |
| 1511 | + data.prompt = apply(tmpl, inputs, /* messages_override= */ std::nullopt, /* tools_override= */ std::nullopt, json { |
1503 | 1512 | {"enable_thinking", inputs.enable_thinking}, |
1504 | | - }; |
1505 | | - |
1506 | | - data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, additional_context); |
| 1513 | + }); |
1507 | 1514 | data.format = COMMON_CHAT_FORMAT_HERMES_2_PRO; |
1508 | 1515 | if (string_ends_with(data.prompt, "<think>\n")) { |
1509 | 1516 | if (!inputs.enable_thinking) { |
@@ -1692,7 +1699,7 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) { |
1692 | 1699 |
|
1693 | 1700 | static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) { |
1694 | 1701 | common_chat_params data; |
1695 | | - data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, inputs.extra_context); |
| 1702 | + data.prompt = apply(tmpl, inputs); |
1696 | 1703 | data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY; |
1697 | 1704 | data.grammar_lazy = false; |
1698 | 1705 | if (!inputs.json_schema.is_null()) { |
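Taken together, the call sites collapse into three shapes, all visible in the hunks above:

    // 1. Pure defaults -- everything comes from inputs:
    data.prompt = apply(tmpl, inputs);

    // 2. Swap in adjusted messages, keep default tools and context:
    data.prompt = apply(tmpl, inputs, /* messages_override= */ adjusted_messages);

    // 3. Keep default messages and tools, layer extra template variables on top:
    data.prompt = apply(tmpl, inputs, std::nullopt, std::nullopt, json {
        {"enable_thinking", inputs.enable_thinking},
    });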