 #include <string>
 #include <vector>

+using json = nlohmann::ordered_json;
+
 static std::string format_time(const std::chrono::system_clock::time_point & now, const std::string & format) {
     auto time = std::chrono::system_clock::to_time_t(now);
     auto local_time = *std::localtime(&time);
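The hunk above is cut off mid-function; the elided remainder of `format_time` presumably streams the broken-down time through `std::put_time`. A standalone sketch under that assumption (the body below is not part of the diff):

```cpp
#include <chrono>
#include <ctime>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>

// Assumed completion of format_time (the diff only shows its first lines).
static std::string format_time_sketch(const std::chrono::system_clock::time_point & now,
                                      const std::string & format) {
    auto time = std::chrono::system_clock::to_time_t(now);
    auto local_time = *std::localtime(&time);           // note: std::localtime is not thread-safe
    std::ostringstream ss;
    ss << std::put_time(&local_time, format.c_str());   // strftime-style, e.g. "%d %b %Y"
    return ss.str();
}

int main() {
    std::cout << format_time_sketch(std::chrono::system_clock::now(), "%d %b %Y") << "\n";
}
```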
@@ -140,6 +142,7 @@ struct templates_params {
     bool add_generation_prompt = true;
     bool enable_thinking = true;
     std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
+    json extra_context;
 };

 common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
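The new `extra_context` member is the carrier for caller-supplied Jinja variables. A purely hypothetical illustration of a populated value (the code that actually fills it from `chat_template_kwargs` appears in the last hunk of this diff):

```cpp
// Hypothetical illustration only: a templates_params with one extra
// Jinja variable set. In the real flow this is filled from
// inputs.chat_template_kwargs (see the last hunk).
templates_params params;
params.extra_context = json {
    {"enable_thinking", false},  // e.g. suppress <think> blocks in templates that support it
};
```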
@@ -720,16 +723,23 @@ static void foreach_function(const json & tools, const std::function<void(const

 static std::string apply(
     const common_chat_template & tmpl,
-    const nlohmann::ordered_json & messages,
-    const nlohmann::ordered_json & tools,
-    bool add_generation_prompt,
-    const nlohmann::ordered_json & extra_context = nlohmann::ordered_json())
+    const struct templates_params & inputs,
+    const std::optional<json> & messages_override = std::nullopt,
+    const std::optional<json> & tools_override = std::nullopt,
+    const std::optional<json> & additional_context = std::nullopt)
 {
     minja::chat_template_inputs tmpl_inputs;
-    tmpl_inputs.messages = messages;
-    tmpl_inputs.tools = tools;
-    tmpl_inputs.add_generation_prompt = add_generation_prompt;
-    tmpl_inputs.extra_context = extra_context;
+    tmpl_inputs.messages = messages_override ? *messages_override : inputs.messages;
+    if (tools_override) {
+        tmpl_inputs.tools = *tools_override;
+    } else {
+        tmpl_inputs.tools = inputs.tools.empty() ? json() : inputs.tools;
+    }
+    tmpl_inputs.add_generation_prompt = inputs.add_generation_prompt;
+    tmpl_inputs.extra_context = inputs.extra_context;
+    if (additional_context) {
+        tmpl_inputs.extra_context.merge_patch(*additional_context);
+    }
     // TODO: add flag to control date/time, if only for testing purposes.
     // tmpl_inputs.now = std::chrono::system_clock::now();
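The key behavioral detail in the new `apply()`: the user-supplied `extra_context` goes in first, then any call-site `additional_context` is layered on top with JSON Merge Patch (RFC 7386), so per-format variables such as `date_string` win on conflicting keys. A standalone demo of that merge semantics:

```cpp
// Standalone demo of the merge order in the new apply():
// additional_context is merge-patched over extra_context, so the
// call site's per-format values win on conflicting keys.
#include <iostream>
#include <nlohmann/json.hpp>

using json = nlohmann::ordered_json;

int main() {
    json extra_context      = {{"date_string", "01 Jan 1970"}, {"enable_thinking", true}};
    json additional_context = {{"date_string", "02 Jun 2025"}};

    extra_context.merge_patch(additional_context);  // RFC 7386 JSON Merge Patch
    std::cout << extra_context.dump() << "\n";
    // {"date_string":"02 Jun 2025","enable_thinking":true}
}
```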
@@ -828,7 +838,7 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
         inputs.messages,
         "Respond in JSON format, either with `tool_call` (a request to call tools) or with `response` reply to the user's request");

-    data.prompt = apply(tmpl, tweaked_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    data.prompt = apply(tmpl, inputs, /* messages_override= */ tweaked_messages);
     data.format = COMMON_CHAT_FORMAT_GENERIC;
     return data;
 }
@@ -904,7 +914,7 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
     data.preserved_tokens = {
         "[TOOL_CALLS]",
     };
-    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    data.prompt = apply(tmpl, inputs);
     data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO;
     return data;
 }
@@ -934,7 +944,7 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
             adjusted_messages.push_back(msg);
         }
     }
-    data.prompt = apply(tmpl, adjusted_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {});
+    data.prompt = apply(tmpl, inputs, /* messages_override= */ adjusted_messages);
     data.format = COMMON_CHAT_FORMAT_COMMAND_R7B;
     if (string_ends_with(data.prompt, "<|START_THINKING|>")) {
         if (!inputs.enable_thinking) {
@@ -1122,7 +1132,7 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te
     } else {
         data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
     }
-    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
+    data.prompt = apply(tmpl, inputs, /* messages_override= */ std::nullopt, /* tools_override= */ std::nullopt, json {
         {"date_string", format_time(inputs.now, "%d %b %Y")},
         {"tools_in_user_message", false},
         {"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
@@ -1187,7 +1197,7 @@ static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool w

 static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
-    auto prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    auto prompt = apply(tmpl, inputs);

     // Hacks to fix the official (broken) prompt.
     // It is advisable to use --chat-template-file models/templates/llama-cpp-deepseek-r1.jinja instead,
@@ -1282,7 +1292,7 @@ static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
 static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
     LOG_DBG("%s\n", __func__);
     common_chat_params data;
-    data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
+    data.prompt = apply(tmpl, inputs, /* messages_override= */ std::nullopt, /* tools_override= */ json(), json {
         {"datetime", format_time(inputs.now, "%b %d %Y %H:%M:%S GMT")},
         {"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
     });
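Worth noting here: the template still gets no native tools (the old `/* tools= */ nullptr` becomes `tools_override = json()`, a null value); the tool list is instead re-serialized with `dump(2)` into the `functions` variable as a JSON string. A standalone demo of that wrapping:

```cpp
// Standalone demo: the "functions" variable is a JSON *string* holding
// pretty-printed tools, not a JSON array.
#include <iostream>
#include <nlohmann/json.hpp>

using json = nlohmann::ordered_json;

int main() {
    json tools = json::array({
        {{"type", "function"}, {"function", {{"name", "get_weather"}}}},
    });
    json functions = json(tools.dump(2));        // wrap the dump in a string value
    std::cout << functions.type_name() << "\n";  // prints: string
}
```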
@@ -1338,7 +1348,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
     // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
     // If the function is python, we also allow raw python code (if the line after `python\n` doesn't start w/ opening `{`), which the model seems to prefer for multiline code.
     common_chat_params data;
-    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    data.prompt = apply(tmpl, inputs);
     data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2;
     if (inputs.tools.is_array() && !inputs.tools.empty()) {
         data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
@@ -1465,7 +1475,7 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
         data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
     }

-    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    data.prompt = apply(tmpl, inputs);
     // TODO: if (has_raw_python)
     return data;
 }
@@ -1498,14 +1508,15 @@ static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser
 static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;

-    json additional_context = {
+    json extra_context = json {
         {"enable_thinking", inputs.enable_thinking},
     };
+    extra_context.update(inputs.extra_context);

-    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, additional_context);
+    data.prompt = apply(tmpl, inputs, /* messages_override= */ std::nullopt, /* tools_override= */ std::nullopt, extra_context);
     data.format = COMMON_CHAT_FORMAT_HERMES_2_PRO;
     if (string_ends_with(data.prompt, "<think>\n")) {
-        if (!inputs.enable_thinking) {
+        if (!extra_context["enable_thinking"]) {
             data.prompt += "</think>";
         } else {
             data.thinking_forced_open = true;
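Note the opposite override order from `apply()` here: the default `enable_thinking` is built first and `update()` then lets the user's `inputs.extra_context` overwrite it, which is why the `<think>` check now reads `extra_context["enable_thinking"]` instead of `inputs.enable_thinking`. A standalone demo:

```cpp
// Standalone demo of the override order used for Hermes-2-Pro:
// the user's chat-template kwargs replace the built-in default.
#include <iostream>
#include <nlohmann/json.hpp>

using json = nlohmann::ordered_json;

int main() {
    json extra_context = {{"enable_thinking", true}};  // default from inputs.enable_thinking
    json user_kwargs   = {{"enable_thinking", false}}; // from inputs.extra_context

    extra_context.update(user_kwargs);                 // update(): existing keys are overwritten
    if (!extra_context["enable_thinking"]) {           // implicit json -> bool conversion
        std::cout << "thinking disabled\n";
    }
}
```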
@@ -1691,7 +1702,7 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {

 static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
-    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    data.prompt = apply(tmpl, inputs);
     data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
     data.grammar_lazy = false;
     if (!inputs.json_schema.is_null()) {
@@ -1722,6 +1733,12 @@ static common_chat_params common_chat_templates_apply_jinja(
     params.enable_thinking = inputs.enable_thinking;
     params.grammar = inputs.grammar;
     params.now = inputs.now;
+
+    params.extra_context = json::object();
+    for (auto el : inputs.chat_template_kwargs) {
+        params.extra_context[el.first] = json::parse(el.second);
+    }
+
     if (!inputs.json_schema.empty()) {
         params.json_schema = json::parse(inputs.json_schema);
     }
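Since each kwarg value goes through `json::parse`, values must be valid JSON text: booleans and numbers appear bare, strings need their own quotes, and malformed input throws. A standalone sketch of the loop above, assuming `chat_template_kwargs` is a `std::map<std::string, std::string>` (the container type is not shown in this diff):

```cpp
// Standalone sketch of the kwargs parsing above. The map type is an
// assumption; each value string must itself be valid JSON.
#include <map>
#include <string>
#include <nlohmann/json.hpp>

using json = nlohmann::ordered_json;

int main() {
    std::map<std::string, std::string> chat_template_kwargs = {
        {"enable_thinking", "false"},      // JSON boolean
        {"custom_header",   "\"<hdr>\""},  // JSON string: quotes required
    };

    json extra_context = json::object();
    for (const auto & el : chat_template_kwargs) {
        extra_context[el.first] = json::parse(el.second);  // throws on invalid JSON
    }
}
```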