|
1 | 1 | #include "chat.h" |
2 | 2 | #include "chat-parser.h" |
| 3 | +#include "chat-peg-parser.h" |
3 | 4 | #include "common.h" |
4 | 5 | #include "json-partial.h" |
5 | 6 | #include "json-schema-to-grammar.h" |
@@ -150,6 +151,7 @@ struct templates_params { |
150 | 151 | common_chat_tool_choice tool_choice; |
151 | 152 | json json_schema; |
152 | 153 | bool parallel_tool_calls; |
| 154 | + common_reasoning_format reasoning_format; |
153 | 155 | bool stream; |
154 | 156 | std::string grammar; |
155 | 157 | bool add_generation_prompt = true; |
@@ -589,6 +591,16 @@ common_chat_templates_ptr common_chat_templates_init( |
589 | 591 | "{%- if false %}"); |
590 | 592 | } |
591 | 593 |
|
| 594 | + // TODO @aldehir : this is a temporary fix, pending Minja changes |
| 595 | + // Ref: https://github.com/ggml-org/llama.cpp/pull/17713#issuecomment-3631342664 |
| 596 | + if (default_template_src.find("[TOOL_CALLS]") != std::string::npos |
| 597 | + // search for the error message and patch it |
| 598 | + && default_template_src.find("if (message['content'] is none or") != std::string::npos) { |
| 599 | + string_replace_all(default_template_src, |
| 600 | + "{%- if (message['content'] is none or message['content'] == '' or message['content']|length == 0) and (message['tool_calls'] is not defined or message['tool_calls'] is none or message['tool_calls']|length == 0) %}", |
| 601 | + "{%- if false %}"); |
| 602 | + } |
| 603 | + |
592 | 604 | std::string token_bos = bos_token_override; |
593 | 605 | std::string token_eos = eos_token_override; |
594 | 606 | bool add_bos = false; |
@@ -987,6 +999,118 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat |
987 | 999 | return data; |
988 | 1000 | } |
989 | 1001 |
|
| 1002 | +static common_chat_params common_chat_params_init_ministral_3(const common_chat_template & tmpl, const struct templates_params & inputs) { |
| 1003 | + common_chat_params data; |
| 1004 | + |
| 1005 | + // Build up messages to follow the format: https://huggingface.co/mistralai/Ministral-3-14B-Reasoning-2512/blob/main/chat_template.jinja |
| 1006 | + auto adjusted_messages = json::array(); |
| 1007 | + for (const auto & msg : inputs.messages) { |
| 1008 | + auto role = msg.value("role", ""); |
| 1009 | + if (role != "system" && role != "assistant") { |
| 1010 | + // Only adjust system and assistant messages. Interestingly, the system message may contain thinking. |
| 1011 | + adjusted_messages.push_back(msg); |
| 1012 | + continue; |
| 1013 | + } |
| 1014 | + |
| 1015 | + auto content = json::array(); |
| 1016 | + |
| 1017 | + // If message contains `reasoning_content`, add it as a block of type `thinking` |
| 1018 | + if (msg.contains("reasoning_content") && msg.at("reasoning_content").is_string()) { |
| 1019 | + content.push_back({ |
| 1020 | + {"type", "thinking"}, |
| 1021 | + {"thinking", msg.at("reasoning_content").get<std::string>()}, |
| 1022 | + }); |
| 1023 | + } |
| 1024 | + |
| 1025 | + // If message contains `content`, add it as a block of type `text` |
| 1026 | + if (msg.contains("content")) { |
| 1027 | + if (msg.at("content").is_string()) { |
| 1028 | + content.push_back({ |
| 1029 | + {"type", "text"}, |
| 1030 | + {"text", msg.at("content").get<std::string>()}, |
| 1031 | + }); |
| 1032 | + } else if (msg.at("content").is_array()) { |
| 1033 | + auto blocks = msg.at("content"); |
| 1034 | + content.insert(content.end(), blocks.begin(), blocks.end()); |
| 1035 | + } |
| 1036 | + } |
| 1037 | + |
| 1038 | + auto adjusted = msg; |
| 1039 | + adjusted["content"] = content; |
| 1040 | + adjusted.erase("reasoning_content"); |
| 1041 | + adjusted_messages.push_back(adjusted); |
| 1042 | + } |
| 1043 | + |
| 1044 | + auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); |
| 1045 | + auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE; |
| 1046 | + auto include_grammar = true; |
| 1047 | + |
| 1048 | + data.prompt = apply(tmpl, inputs, /* messages_override = */ adjusted_messages); |
| 1049 | + data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; |
| 1050 | + data.preserved_tokens = { |
| 1051 | + "[THINK]", |
| 1052 | + "[/THINK]", |
| 1053 | + "[TOOL_CALLS]", |
| 1054 | + "[ARGS]", |
| 1055 | + }; |
| 1056 | + |
| 1057 | + auto parser = build_chat_peg_native_parser([&](common_chat_peg_native_builder & p) { |
| 1058 | + auto reasoning = extract_reasoning ? p.optional("[THINK]" + p.reasoning(p.until("[/THINK]")) + "[/THINK]") : p.eps(); |
| 1059 | + |
| 1060 | + // Response format parser |
| 1061 | + if (inputs.json_schema.is_object() && !inputs.json_schema.empty()) { |
| 1062 | + // Ministral wants to emit json surrounded by code fences |
| 1063 | + return reasoning << "```json" << p.content(p.schema(p.json(), "response-format", inputs.json_schema)) << "```"; |
| 1064 | + } |
| 1065 | + |
| 1066 | + // Tool call parser |
| 1067 | + if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) { |
| 1068 | + auto tool_choice = p.choice(); |
| 1069 | + foreach_function(inputs.tools, [&](const json & tool) { |
| 1070 | + const auto & function = tool.at("function"); |
| 1071 | + std::string name = function.at("name"); |
| 1072 | + const auto & schema = function.at("parameters"); |
| 1073 | + |
| 1074 | + tool_choice |= p.rule("tool-" + name, |
| 1075 | + p.tool_open(p.tool_name(p.literal(name)) + "[ARGS]") |
| 1076 | + + p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema)) |
| 1077 | + ); |
| 1078 | + }); |
| 1079 | + |
| 1080 | + auto min_calls = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED ? 1 : 0; |
| 1081 | + auto max_calls = inputs.parallel_tool_calls ? -1 : 1; |
| 1082 | + auto tool_calls = p.trigger_rule("tool-call", p.repeat("[TOOL_CALLS]" + tool_choice, min_calls, max_calls)); |
| 1083 | + |
| 1084 | + return reasoning << p.content(p.until("[TOOL_CALLS]")) << tool_calls; |
| 1085 | + } |
| 1086 | + |
| 1087 | + // Content only parser |
| 1088 | + include_grammar = false; |
| 1089 | + return reasoning << p.content(p.rest()); |
| 1090 | + }); |
| 1091 | + |
| 1092 | + data.parser = parser.save(); |
| 1093 | + |
| 1094 | + if (include_grammar) { |
| 1095 | + data.grammar_lazy = has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO; |
| 1096 | + |
| 1097 | + data.grammar = build_grammar([&](const common_grammar_builder & builder) { |
| 1098 | + foreach_function(inputs.tools, [&](const json & tool) { |
| 1099 | + const auto & function = tool.at("function"); |
| 1100 | + auto schema = function.at("parameters"); |
| 1101 | + builder.resolve_refs(schema); |
| 1102 | + }); |
| 1103 | + parser.build_grammar(builder, data.grammar_lazy); |
| 1104 | + }); |
| 1105 | + |
| 1106 | + data.grammar_triggers = { |
| 1107 | + {COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "[TOOL_CALLS]"} |
| 1108 | + }; |
| 1109 | + } |
| 1110 | + |
| 1111 | + return data; |
| 1112 | +} |
| 1113 | + |
990 | 1114 | static common_chat_params common_chat_params_init_magistral(const common_chat_template & tmpl, const struct templates_params & inputs) { |
991 | 1115 | common_chat_params data; |
992 | 1116 | data.prompt = apply(tmpl, inputs); |
@@ -2341,6 +2465,7 @@ static common_chat_params common_chat_templates_apply_jinja( |
2341 | 2465 | params.messages = common_chat_msgs_to_json_oaicompat<json>(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content); |
2342 | 2466 | params.add_generation_prompt = inputs.add_generation_prompt; |
2343 | 2467 | params.tool_choice = inputs.tool_choice; |
| 2468 | + params.reasoning_format = inputs.reasoning_format; |
2344 | 2469 | params.enable_thinking = inputs.enable_thinking; |
2345 | 2470 | params.grammar = inputs.grammar; |
2346 | 2471 | params.now = inputs.now; |
@@ -2504,6 +2629,13 @@ static common_chat_params common_chat_templates_apply_jinja( |
2504 | 2629 | return common_chat_params_init_llama_3_x(tmpl, params, allow_python_tag_builtin_tools); |
2505 | 2630 | } |
2506 | 2631 |
|
| 2632 | + // Ministral/Mistral Large 3 |
| 2633 | + if (src.find("[SYSTEM_PROMPT]") != std::string::npos && |
| 2634 | + src.find("[TOOL_CALLS]") != std::string::npos && |
| 2635 | + src.find("[ARGS]") != std::string::npos) { |
| 2636 | + return common_chat_params_init_ministral_3(tmpl, params); |
| 2637 | + } |
| 2638 | + |
2507 | 2639 | if (src.find("[THINK]") != std::string::npos && src.find("[/THINK]") != std::string::npos) { |
2508 | 2640 | return common_chat_params_init_magistral(tmpl, params); |
2509 | 2641 | } |
|
0 commit comments