|
1 | 1 | #include "chat.h" |
| 2 | +#include "chat-parser.h" |
| 3 | +#include "chat-peg-parser.h" |
2 | 4 | #include "common.h" |
3 | 5 | #include "json-partial.h" |
4 | 6 | #include "json-schema-to-grammar.h" |
@@ -150,6 +152,7 @@ struct templates_params { |
150 | 152 | common_chat_tool_choice tool_choice; |
151 | 153 | json json_schema; |
152 | 154 | bool parallel_tool_calls; |
| 155 | + common_reasoning_format reasoning_format; |
153 | 156 | bool stream; |
154 | 157 | std::string grammar; |
155 | 158 | bool add_generation_prompt = true; |
@@ -594,6 +597,16 @@ common_chat_templates_ptr common_chat_templates_init( |
594 | 597 | "{%- if false %}"); |
595 | 598 | } |
596 | 599 |
|
| 600 | + // TODO @aldehir : this is a temporary fix, pending Minja changes |
| 601 | + // Ref: https://github.com/ggml-org/llama.cpp/pull/17713#issuecomment-3631342664 |
| 602 | + if (default_template_src.find("[TOOL_CALLS]") != std::string::npos |
| 603 | + // search for the error message and patch it |
| 604 | + && default_template_src.find("if (message['content'] is none or") != std::string::npos) { |
| 605 | + string_replace_all(default_template_src, |
| 606 | + "{%- if (message['content'] is none or message['content'] == '' or message['content']|length == 0) and (message['tool_calls'] is not defined or message['tool_calls'] is none or message['tool_calls']|length == 0) %}", |
| 607 | + "{%- if false %}"); |
| 608 | + } |
| 609 | + |
597 | 610 | std::string token_bos = bos_token_override; |
598 | 611 | std::string token_eos = eos_token_override; |
599 | 612 | bool add_bos = false; |
@@ -992,6 +1005,118 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat |
992 | 1005 | return data; |
993 | 1006 | } |
994 | 1007 |
|
| 1008 | +static common_chat_params common_chat_params_init_ministral_3(const common_chat_template & tmpl, const struct templates_params & inputs) { |
| 1009 | + common_chat_params data; |
| 1010 | + |
| 1011 | + // Build up messages to follow the format: https://huggingface.co/mistralai/Ministral-3-14B-Reasoning-2512/blob/main/chat_template.jinja |
| 1012 | + auto adjusted_messages = json::array(); |
| 1013 | + for (const auto & msg : inputs.messages) { |
| 1014 | + auto role = msg.value("role", ""); |
| 1015 | + if (role != "system" && role != "assistant") { |
| 1016 | + // Only adjust system and assistant messages. Interestingly, the system message may contain thinking. |
| 1017 | + adjusted_messages.push_back(msg); |
| 1018 | + continue; |
| 1019 | + } |
| 1020 | + |
| 1021 | + auto content = json::array(); |
| 1022 | + |
| 1023 | + // If message contains `reasoning_content`, add it as a block of type `thinking` |
| 1024 | + if (msg.contains("reasoning_content") && msg.at("reasoning_content").is_string()) { |
| 1025 | + content.push_back({ |
| 1026 | + {"type", "thinking"}, |
| 1027 | + {"thinking", msg.at("reasoning_content").get<std::string>()}, |
| 1028 | + }); |
| 1029 | + } |
| 1030 | + |
| 1031 | + // If message contains `content`, add it as a block of type `text` |
| 1032 | + if (msg.contains("content")) { |
| 1033 | + if (msg.at("content").is_string()) { |
| 1034 | + content.push_back({ |
| 1035 | + {"type", "text"}, |
| 1036 | + {"text", msg.at("content").get<std::string>()}, |
| 1037 | + }); |
| 1038 | + } else if (msg.at("content").is_array()) { |
| 1039 | + auto blocks = msg.at("content"); |
| 1040 | + content.insert(content.end(), blocks.begin(), blocks.end()); |
| 1041 | + } |
| 1042 | + } |
| 1043 | + |
| 1044 | + auto adjusted = msg; |
| 1045 | + adjusted["content"] = content; |
| 1046 | + adjusted.erase("reasoning_content"); |
| 1047 | + adjusted_messages.push_back(adjusted); |
| 1048 | + } |
| 1049 | + |
| 1050 | + auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); |
| 1051 | + auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE; |
| 1052 | + auto include_grammar = true; |
| 1053 | + |
| 1054 | + data.prompt = apply(tmpl, inputs, /* messages_override = */ adjusted_messages); |
| 1055 | + data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; |
| 1056 | + data.preserved_tokens = { |
| 1057 | + "[THINK]", |
| 1058 | + "[/THINK]", |
| 1059 | + "[TOOL_CALLS]", |
| 1060 | + "[ARGS]", |
| 1061 | + }; |
| 1062 | + |
| 1063 | + auto parser = build_chat_peg_native_parser([&](common_chat_peg_native_builder & p) { |
| 1064 | + auto reasoning = extract_reasoning ? p.optional("[THINK]" + p.reasoning(p.until("[/THINK]")) + "[/THINK]") : p.eps(); |
| 1065 | + |
| 1066 | + // Response format parser |
| 1067 | + if (inputs.json_schema.is_object() && !inputs.json_schema.empty()) { |
| 1068 | + // Ministral wants to emit json surrounded by code fences |
| 1069 | + return reasoning << "```json" << p.content(p.schema(p.json(), "response-format", inputs.json_schema)) << "```"; |
| 1070 | + } |
| 1071 | + |
| 1072 | + // Tool call parser |
| 1073 | + if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) { |
| 1074 | + auto tool_choice = p.choice(); |
| 1075 | + foreach_function(inputs.tools, [&](const json & tool) { |
| 1076 | + const auto & function = tool.at("function"); |
| 1077 | + std::string name = function.at("name"); |
| 1078 | + const auto & schema = function.at("parameters"); |
| 1079 | + |
| 1080 | + tool_choice |= p.rule("tool-" + name, |
| 1081 | + p.tool_open(p.tool_name(p.literal(name)) + "[ARGS]") |
| 1082 | + + p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema)) |
| 1083 | + ); |
| 1084 | + }); |
| 1085 | + |
| 1086 | + auto min_calls = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED ? 1 : 0; |
| 1087 | + auto max_calls = inputs.parallel_tool_calls ? -1 : 1; |
| 1088 | + auto tool_calls = p.trigger_rule("tool-call", p.repeat("[TOOL_CALLS]" + tool_choice, min_calls, max_calls)); |
| 1089 | + |
| 1090 | + return reasoning << p.content(p.until("[TOOL_CALLS]")) << tool_calls; |
| 1091 | + } |
| 1092 | + |
| 1093 | + // Content only parser |
| 1094 | + include_grammar = false; |
| 1095 | + return reasoning << p.content(p.rest()); |
| 1096 | + }); |
| 1097 | + |
| 1098 | + data.parser = parser.save(); |
| 1099 | + |
| 1100 | + if (include_grammar) { |
| 1101 | + data.grammar_lazy = has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO; |
| 1102 | + |
| 1103 | + data.grammar = build_grammar([&](const common_grammar_builder & builder) { |
| 1104 | + foreach_function(inputs.tools, [&](const json & tool) { |
| 1105 | + const auto & function = tool.at("function"); |
| 1106 | + auto schema = function.at("parameters"); |
| 1107 | + builder.resolve_refs(schema); |
| 1108 | + }); |
| 1109 | + parser.build_grammar(builder, data.grammar_lazy); |
| 1110 | + }); |
| 1111 | + |
| 1112 | + data.grammar_triggers = { |
| 1113 | + {COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "[TOOL_CALLS]"} |
| 1114 | + }; |
| 1115 | + } |
| 1116 | + |
| 1117 | + return data; |
| 1118 | +} |
| 1119 | + |
995 | 1120 | static common_chat_params common_chat_params_init_magistral(const common_chat_template & tmpl, const struct templates_params & inputs) { |
996 | 1121 | common_chat_params data; |
997 | 1122 | data.prompt = apply(tmpl, inputs); |
@@ -2346,6 +2471,7 @@ static common_chat_params common_chat_templates_apply_jinja( |
2346 | 2471 | params.messages = common_chat_msgs_to_json_oaicompat<json>(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content); |
2347 | 2472 | params.add_generation_prompt = inputs.add_generation_prompt; |
2348 | 2473 | params.tool_choice = inputs.tool_choice; |
| 2474 | + params.reasoning_format = inputs.reasoning_format; |
2349 | 2475 | params.enable_thinking = inputs.enable_thinking; |
2350 | 2476 | params.grammar = inputs.grammar; |
2351 | 2477 | params.now = inputs.now; |
@@ -2509,6 +2635,13 @@ static common_chat_params common_chat_templates_apply_jinja( |
2509 | 2635 | return common_chat_params_init_llama_3_x(tmpl, params, allow_python_tag_builtin_tools); |
2510 | 2636 | } |
2511 | 2637 |
|
| 2638 | + // Ministral/Mistral Large 3 |
| 2639 | + if (src.find("[SYSTEM_PROMPT]") != std::string::npos && |
| 2640 | + src.find("[TOOL_CALLS]") != std::string::npos && |
| 2641 | + src.find("[ARGS]") != std::string::npos) { |
| 2642 | + return common_chat_params_init_ministral_3(tmpl, params); |
| 2643 | + } |
| 2644 | + |
2512 | 2645 | if (src.find("[THINK]") != std::string::npos && src.find("[/THINK]") != std::string::npos) { |
2513 | 2646 | return common_chat_params_init_magistral(tmpl, params); |
2514 | 2647 | } |
|
0 commit comments