|
9 | 9 | #include <minja/chat-template.hpp> |
10 | 10 | #include <minja/minja.hpp> |
11 | 11 |
|
| 12 | +#include <algorithm> |
12 | 13 | #include <cstdio> |
| 14 | +#include <cctype> |
13 | 15 | #include <exception> |
| 16 | +#include <functional> |
14 | 17 | #include <iostream> |
15 | 18 | #include <optional> |
16 | 19 | #include <stdexcept> |
@@ -643,6 +646,7 @@ const char * common_chat_format_name(common_chat_format format) { |
643 | 646 | case COMMON_CHAT_FORMAT_SEED_OSS: return "Seed-OSS"; |
644 | 647 | case COMMON_CHAT_FORMAT_NEMOTRON_V2: return "Nemotron V2"; |
645 | 648 | case COMMON_CHAT_FORMAT_APERTUS: return "Apertus"; |
| 649 | + case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS: return "LFM2 with JSON tools"; |
646 | 650 | case COMMON_CHAT_FORMAT_GLM_4_5: return "GLM 4.5"; |
647 | 651 | default: |
648 | 652 | throw std::runtime_error("Unknown chat format"); |
@@ -990,6 +994,126 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat |
990 | 994 | return data; |
991 | 995 | } |
992 | 996 |
|
| 997 | + |
| 998 | +// Case-insensitive find (characters go through unsigned char so std::tolower is well-defined) |
| 999 | +static size_t ifind_string(const std::string & haystack, const std::string & needle, size_t pos = 0) { |
| 1000 | + auto it = std::search( |
| 1001 | + haystack.begin() + pos, haystack.end(), |
| 1002 | + needle.begin(), needle.end(), |
| 1003 | + [](unsigned char a, unsigned char b) { return std::tolower(a) == std::tolower(b); } |
| 1004 | + ); |
| 1005 | + return (it == haystack.end()) ? std::string::npos : std::distance(haystack.begin(), it); |
| 1006 | +} |
| 1007 | + |
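A quick standalone sketch of how `ifind_string` behaves. The helper is copied here so the snippet compiles on its own; the input string and offsets are purely illustrative:

```cpp
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <iterator>
#include <string>

// Same helper as in the patch, with the unsigned char parameters that keep
// std::tolower well-defined for negative char values.
static size_t ifind_string(const std::string & haystack, const std::string & needle, size_t pos = 0) {
    auto it = std::search(
        haystack.begin() + pos, haystack.end(),
        needle.begin(), needle.end(),
        [](unsigned char a, unsigned char b) { return std::tolower(a) == std::tolower(b); }
    );
    return (it == haystack.end()) ? std::string::npos : std::distance(haystack.begin(), it);
}

int main() {
    const std::string sys = "You are helpful. FORCE JSON SCHEMA.\nBe terse.";
    // Case differs from the needle, but the match is still found (offset 17).
    std::printf("%zu\n", ifind_string(sys, "force json schema."));
    // No match: returns std::string::npos (prints 1).
    std::printf("%d\n", ifind_string(sys, "follow json schema.") == std::string::npos);
}
```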
| 1008 | +static common_chat_params common_chat_params_init_lfm2(const common_chat_template & tmpl, const struct templates_params & inputs) { |
| 1009 | + common_chat_params data; |
| 1010 | + const auto is_json_schema_provided = !inputs.json_schema.is_null(); |
| 1011 | + const auto is_grammar_provided = !inputs.grammar.empty(); |
| 1012 | + const auto are_tools_provided = inputs.tools.is_array() && !inputs.tools.empty(); |
| 1013 | + |
| 1014 | + // the logic below may need to modify the messages, so work on a copy |
| 1015 | + auto tweaked_messages = inputs.messages; |
| 1016 | + |
| 1017 | + auto replace_json_schema_marker = [](json & messages) -> bool { |
| 1018 | + static std::string marker1 = "force json schema.\n"; |
| 1019 | + static std::string marker2 = "force json schema."; |
| 1020 | + |
| 1021 | + if (messages.empty() || messages.at(0).at("role") != "system") { |
| 1022 | + return false; |
| 1023 | + } |
| 1024 | + |
| 1025 | + std::string content = messages.at(0).at("content"); |
| 1026 | + |
| 1027 | + for (const auto & marker : {marker1, marker2}) { |
| 1028 | + const auto pos = ifind_string(content, marker); |
| 1029 | + if (pos != std::string::npos) { |
| 1030 | + content.replace(pos, marker.length(), ""); |
| 1031 | + // inject modified content back into the messages |
| 1032 | + messages.at(0).at("content") = content; |
| 1033 | + return true; |
| 1034 | + } |
| 1035 | + } |
| 1036 | + |
| 1037 | + return false; |
| 1038 | + }; |
| 1039 | + |
| 1040 | + // The LFM2 model does not natively emit JSON tool calls, but it generally understands the tool-call structure |
| 1041 | + // |
| 1042 | + // Example of the pytorch dialog structure: |
| 1043 | + // <|startoftext|><|im_start|>system |
| 1044 | + // List of tools: <|tool_list_start|>[{"name": "get_candidate_status", "description": "Retrieves the current status of a candidate in the recruitment process", "parameters": {"type": "object", "properties": {"candidate_id": {"type": "string", "description": "Unique identifier for the candidate"}}, "required": ["candidate_id"]}}]<|tool_list_end|><|im_end|> |
| 1045 | + // <|im_start|>user |
| 1046 | + // What is the current status of candidate ID 12345?<|im_end|> |
| 1047 | + // <|im_start|>assistant |
| 1048 | + // <|tool_call_start|>[get_candidate_status(candidate_id="12345")]<|tool_call_end|>Checking the current status of candidate ID 12345.<|im_end|> |
| 1049 | + // <|im_start|>tool |
| 1050 | + // <|tool_response_start|>{"candidate_id": "12345", "status": "Interview Scheduled", "position": "Clinical Research Associate", "date": "2023-11-20"}<|tool_response_end|><|im_end|> |
| 1051 | + // <|im_start|>assistant |
| 1052 | + // The candidate with ID 12345 is currently in the "Interview Scheduled" stage for the position of Clinical Research Associate, with an interview date set for 2023-11-20.<|im_end|> |
| 1053 | + // |
| 1054 | + // For compatibility with the llama-server JSON tools semantics, |
| 1055 | + // the client can add a "Force json schema." line (matched case-insensitively) to the system message to force JSON output. |
| 1056 | + // |
| 1057 | + if (are_tools_provided && (is_json_schema_provided || is_grammar_provided)) { |
| 1058 | + // server/utils.hpp prohibits this branch for custom grammars anyway |
| 1059 | + throw std::runtime_error("Tool calls must not use \"json_schema\" or \"grammar\"; use a non-tool invocation if you want a custom grammar"); |
| 1060 | + } else if (are_tools_provided && replace_json_schema_marker(tweaked_messages)) { |
| 1061 | + LOG_INF("%s: Using tools to build a grammar\n", __func__); |
| 1062 | + |
| 1063 | + data.grammar = build_grammar([&](const common_grammar_builder & builder) { |
| 1064 | + auto schemas = json::array(); |
| 1065 | + foreach_function(inputs.tools, [&](const json & tool) { |
| 1066 | + const auto & function = tool.at("function"); |
| 1067 | + schemas.push_back({ |
| 1068 | + {"type", "object"}, |
| 1069 | + {"properties", { |
| 1070 | + {"name", { |
| 1071 | + {"type", "string"}, |
| 1072 | + {"const", function.at("name")}, |
| 1073 | + }}, |
| 1074 | + {"arguments", function.at("parameters")}, |
| 1075 | + }}, |
| 1076 | + {"required", json::array({"name", "arguments", "id"})}, |
| 1077 | + }); |
| 1078 | + }); |
| 1079 | + auto schema = json { |
| 1080 | + {"type", "array"}, |
| 1081 | + {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}}, |
| 1082 | + {"minItems", 1}, |
| 1083 | + }; |
| 1084 | + if (!inputs.parallel_tool_calls) { |
| 1085 | + schema["maxItems"] = 1; |
| 1086 | + } |
| 1087 | + |
| 1088 | + builder.add_rule("root", "\"<|tool_call_start|>\"" + builder.add_schema("tool_calls", schema) + "\"<|tool_call_end|>\""); |
| 1089 | + }); |
| 1090 | + // The model has no concept of a tool-choice mode; if the system prompt |
| 1091 | + // rendered correctly, it will produce a tool call, and the grammar |
| 1092 | + // constrains the tool call body. |
| 1093 | + data.grammar_lazy = true; |
| 1094 | + data.grammar_triggers = {{COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, "\\s*<\\|tool_call_start\\|>\\s*\\["}}; |
| 1095 | + data.preserved_tokens = {"<|tool_call_start|>", "<|tool_call_end|>"}; |
| 1096 | + data.format = COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS; |
| 1097 | + } else if (are_tools_provided && (!is_json_schema_provided && !is_grammar_provided)) { |
| 1098 | + LOG_INF("%s: Using tools without json schema or grammar\n", __func__); |
| 1099 | + // keep the tool-call markers visible in the output |
| 1100 | + data.preserved_tokens = {"<|tool_call_start|>", "<|tool_call_end|>"}; |
| 1101 | + } else if (is_json_schema_provided) { |
| 1102 | + LOG_INF("%s: Using provided json schema to build a grammar\n", __func__); |
| 1103 | + data.grammar = json_schema_to_grammar(inputs.json_schema); |
| 1104 | + } else if (is_grammar_provided) { |
| 1105 | + LOG_INF("%s: Using provided grammar\n", __func__); |
| 1106 | + data.grammar = inputs.grammar; |
| 1107 | + } else { |
| 1108 | + LOG_INF("%s: Using plain content, relying on the template\n", __func__); |
| 1109 | + } |
| 1110 | + |
| 1111 | + data.prompt = apply(tmpl, inputs, /* messages_override= */ tweaked_messages); |
| 1112 | + LOG_DBG("%s: Prompt: %s\n", __func__, data.prompt.c_str()); |
| 1113 | + |
| 1114 | + return data; |
| 1115 | +} |
| 1116 | + |
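For orientation, a rough sketch of a client request that would take the grammar-constrained branch above. The payload follows the OpenAI-compatible chat-completions shape that llama-server accepts; the tool definition is the one from the comment block, and the exact field layout is illustrative rather than authoritative:

```cpp
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    // The "force json schema." marker (case-insensitive) is stripped from the
    // system message by replace_json_schema_marker() before the template is
    // rendered; the tools array is what the grammar is built from.
    const json request = {
        {"messages", json::array({
            {{"role", "system"},
             {"content", "You are a recruiting assistant. Force json schema.\n"}},
            {{"role", "user"},
             {"content", "What is the current status of candidate ID 12345?"}},
        })},
        {"tools", json::array({
            {{"type", "function"},
             {"function", {
                 {"name", "get_candidate_status"},
                 {"parameters", {
                     {"type", "object"},
                     {"properties", {
                         {"candidate_id", {{"type", "string"}}},
                     }},
                     {"required", json::array({"candidate_id"})},
                 }}}},
        })},
    };
    // With the lazy grammar triggered, generation is constrained to the shape
    // <|tool_call_start|>[{"name": "get_candidate_status", "arguments": {...}, ...}]<|tool_call_end|>
    (void) request;
}
```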
993 | 1117 | static common_chat_params common_chat_params_init_magistral(const common_chat_template & tmpl, const struct templates_params & inputs) { |
994 | 1118 | common_chat_params data; |
995 | 1119 | data.prompt = apply(tmpl, inputs); |
@@ -2891,6 +3015,71 @@ static void common_chat_parse_apertus(common_chat_msg_parser & builder) { |
2891 | 3015 | builder.add_content(builder.consume_rest()); |
2892 | 3016 | } |
2893 | 3017 |
|
| 3018 | + |
| 3019 | +static void common_chat_parse_lfm2(common_chat_msg_parser & builder) { |
| 3020 | + if (!builder.syntax().parse_tool_calls) { |
| 3021 | + builder.add_content(builder.consume_rest()); |
| 3022 | + return; |
| 3023 | + } |
| 3024 | + |
| 3025 | + // LFM2 format: <|tool_call_start|>[{"name": "get_current_time", "arguments": {"location": "Paris"}}]<|tool_call_end|> |
| 3026 | + static const common_regex tool_call_start_regex(regex_escape("<|tool_call_start|>")); |
| 3027 | + static const common_regex tool_call_end_regex(regex_escape("<|tool_call_end|>")); |
| 3028 | + |
| 3029 | + // Loop through all tool calls |
| 3030 | + while (auto res = builder.try_find_regex(tool_call_start_regex, std::string::npos, /* add_prelude_to_content= */ true)) { |
| 3031 | + builder.move_to(res->groups[0].end); |
| 3032 | + |
| 3033 | + // Parse JSON array format: [{"name": "...", "arguments": {...}}] |
| 3034 | + auto tool_calls_data = builder.consume_json(); |
| 3035 | + |
| 3036 | + // Consume end marker |
| 3037 | + builder.consume_spaces(); |
| 3038 | + if (!builder.try_consume_regex(tool_call_end_regex)) { |
| 3039 | + throw common_chat_msg_partial_exception("Expected <|tool_call_end|>"); |
| 3040 | + } |
| 3041 | + |
| 3042 | + // Process each tool call in the array |
| 3043 | + if (tool_calls_data.json.is_array()) { |
| 3044 | + for (const auto & tool_call : tool_calls_data.json) { |
| 3045 | + if (!tool_call.is_object()) { |
| 3046 | + throw common_chat_msg_partial_exception("Tool call must be an object"); |
| 3047 | + } |
| 3048 | + |
| 3049 | + if (!tool_call.contains("name")) { |
| 3050 | + throw common_chat_msg_partial_exception("Tool call missing 'name' field"); |
| 3051 | + } |
| 3052 | + |
| 3053 | + std::string function_name = tool_call.at("name"); |
| 3054 | + std::string arguments = "{}"; |
| 3055 | + |
| 3056 | + if (tool_call.contains("arguments")) { |
| 3057 | + if (tool_call.at("arguments").is_object()) { |
| 3058 | + arguments = tool_call.at("arguments").dump(); |
| 3059 | + } else if (tool_call.at("arguments").is_string()) { |
| 3060 | + arguments = tool_call.at("arguments"); |
| 3061 | + } |
| 3062 | + } |
| 3063 | + |
| 3064 | + if (!builder.add_tool_call(function_name, "", arguments)) { |
| 3065 | + throw common_chat_msg_partial_exception("Incomplete tool call"); |
| 3066 | + } |
| 3067 | + } |
| 3068 | + } else { |
| 3069 | + throw common_chat_msg_partial_exception("Expected JSON array for tool calls"); |
| 3070 | + } |
| 3071 | + |
| 3072 | + // Consume any trailing whitespace after this tool call |
| 3073 | + builder.consume_spaces(); |
| 3074 | + } |
| 3075 | + |
| 3076 | + // Consume any remaining content after all tool calls |
| 3077 | + auto remaining = builder.consume_rest(); |
| 3078 | + if (!string_strip(remaining).empty()) { |
| 3079 | + builder.add_content(remaining); |
| 3080 | + } |
| 3081 | +} |
| 3082 | + |
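And a self-contained sketch of the wire format this parser consumes. It deliberately avoids the internal `common_chat_msg_parser` API and splits a single tool-call span by hand with nlohmann::json, so the expected structure is easy to see:

```cpp
#include <nlohmann/json.hpp>

#include <cstdio>
#include <string>

int main() {
    const std::string out =
        "Checking now. <|tool_call_start|>[{\"name\": \"get_current_time\", "
        "\"arguments\": {\"location\": \"Paris\"}}]<|tool_call_end|>";

    const std::string start = "<|tool_call_start|>";
    const std::string end   = "<|tool_call_end|>";

    const auto b = out.find(start);
    const auto e = b == std::string::npos ? b : out.find(end, b);
    if (b == std::string::npos || e == std::string::npos) {
        std::printf("no tool call\n");
        return 0;
    }

    // Text before the start marker is plain assistant content (the "prelude").
    std::printf("content: '%s'\n", out.substr(0, b).c_str());

    // Between the markers sits a JSON array of {"name", "arguments"} objects.
    const auto calls = nlohmann::json::parse(out.substr(b + start.size(), e - b - start.size()));
    for (const auto & call : calls) {
        std::printf("tool: %s  args: %s\n",
                    call.at("name").get<std::string>().c_str(),
                    call.at("arguments").dump().c_str());
    }
}
```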
2894 | 3083 | static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) { |
2895 | 3084 | // Parse thinking tags first - this handles the main reasoning content |
2896 | 3085 | builder.try_parse_reasoning("<seed:think>", "</seed:think>"); |
@@ -3145,6 +3334,12 @@ static common_chat_params common_chat_templates_apply_jinja( |
3145 | 3334 | return common_chat_params_init_apertus(tmpl, params); |
3146 | 3335 | } |
3147 | 3336 |
|
| 3337 | + // LFM2 (w/ tools) |
| 3338 | + if (src.find("List of tools: <|tool_list_start|>[") != std::string::npos && |
| 3339 | + src.find("]<|tool_list_end|>") != std::string::npos) { |
| 3340 | + return common_chat_params_init_lfm2(tmpl, params); |
| 3341 | + } |
| 3342 | + |
3148 | 3343 | // Use generic handler when mixing tools + JSON schema. |
3149 | 3344 | // TODO: support that mix in handlers below. |
3150 | 3345 | if ((params.tools.is_array() && params.json_schema.is_object())) { |
@@ -3323,6 +3518,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) { |
3323 | 3518 | case COMMON_CHAT_FORMAT_APERTUS: |
3324 | 3519 | common_chat_parse_apertus(builder); |
3325 | 3520 | break; |
| 3521 | + case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS: |
| 3522 | + common_chat_parse_lfm2(builder); |
| 3523 | + break; |
3326 | 3524 | case COMMON_CHAT_FORMAT_GLM_4_5: |
3327 | 3525 | common_chat_parse_glm_4_5(builder); |
3328 | 3526 | break; |