Commit c053e18

chat: Add LFM2 tool handling (#16763)
* Add LFM2 tool handling
* fmt
* Apply suggestion from @ykhrustalev
1 parent e1ab084 commit c053e18

4 files changed: +385 -0 lines changed
common/chat.cpp

Lines changed: 198 additions & 0 deletions
@@ -9,8 +9,11 @@
 #include <minja/chat-template.hpp>
 #include <minja/minja.hpp>
 
+#include <algorithm>
 #include <cstdio>
+#include <cctype>
 #include <exception>
+#include <functional>
 #include <iostream>
 #include <optional>
 #include <stdexcept>
@@ -640,6 +643,7 @@ const char * common_chat_format_name(common_chat_format format) {
         case COMMON_CHAT_FORMAT_SEED_OSS: return "Seed-OSS";
         case COMMON_CHAT_FORMAT_NEMOTRON_V2: return "Nemotron V2";
         case COMMON_CHAT_FORMAT_APERTUS: return "Apertus";
+        case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS: return "LFM2 with JSON tools";
         default:
             throw std::runtime_error("Unknown chat format");
     }
@@ -986,6 +990,126 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
     return data;
 }
 
+
+// Case-insensitive find
+static size_t ifind_string(const std::string & haystack, const std::string & needle, size_t pos = 0) {
+    auto it = std::search(
+        haystack.begin() + pos, haystack.end(),
+        needle.begin(), needle.end(),
+        [](char a, char b) { return std::tolower(a) == std::tolower(b); }
+    );
+    return (it == haystack.end()) ? std::string::npos : std::distance(haystack.begin(), it);
+}
+
+static common_chat_params common_chat_params_init_lfm2(const common_chat_template & tmpl, const struct templates_params & inputs) {
+    common_chat_params data;
+    const auto is_json_schema_provided = !inputs.json_schema.is_null();
+    const auto is_grammar_provided = !inputs.grammar.empty();
+    const auto are_tools_provided = inputs.tools.is_array() && !inputs.tools.empty();
+
+    // the logic below may need to modify the messages
+    auto tweaked_messages = inputs.messages;
+
+    auto replace_json_schema_marker = [](json & messages) -> bool {
+        static std::string marker1 = "force json schema.\n";
+        static std::string marker2 = "force json schema.";
+
+        if (messages.empty() || messages.at(0).at("role") != "system") {
+            return false;
+        }
+
+        std::string content = messages.at(0).at("content");
+
+        for (const auto & marker : {marker1, marker2}) {
+            const auto pos = ifind_string(content, marker);
+            if (pos != std::string::npos) {
+                content.replace(pos, marker.length(), "");
+                // inject the modified content back into the messages
+                messages.at(0).at("content") = content;
+                return true;
+            }
+        }
+
+        return false;
+    };
+
+    // The LFM2 model does not natively work with JSON, but it can generally understand the tools structure
+    //
+    // Example of the PyTorch dialog structure:
+    // <|startoftext|><|im_start|>system
+    // List of tools: <|tool_list_start|>[{"name": "get_candidate_status", "description": "Retrieves the current status of a candidate in the recruitment process", "parameters": {"type": "object", "properties": {"candidate_id": {"type": "string", "description": "Unique identifier for the candidate"}}, "required": ["candidate_id"]}}]<|tool_list_end|><|im_end|>
+    // <|im_start|>user
+    // What is the current status of candidate ID 12345?<|im_end|>
+    // <|im_start|>assistant
+    // <|tool_call_start|>[get_candidate_status(candidate_id="12345")]<|tool_call_end|>Checking the current status of candidate ID 12345.<|im_end|>
+    // <|im_start|>tool
+    // <|tool_response_start|>{"candidate_id": "12345", "status": "Interview Scheduled", "position": "Clinical Research Associate", "date": "2023-11-20"}<|tool_response_end|><|im_end|>
+    // <|im_start|>assistant
+    // The candidate with ID 12345 is currently in the "Interview Scheduled" stage for the position of Clinical Research Associate, with an interview date set for 2023-11-20.<|im_end|>
+    //
+    // For llama-server compatibility with the JSON tools semantics,
+    // the client can add a "Force json schema." line to the system message prompt to force JSON output.
+    //
+    if (are_tools_provided && (is_json_schema_provided || is_grammar_provided)) {
+        // server/utils.hpp prohibits that branch for a custom grammar anyway
+        throw std::runtime_error("Tools call must not use \"json_schema\" or \"grammar\", use non-tool invocation if you want to use custom grammar");
+    } else if (are_tools_provided && replace_json_schema_marker(tweaked_messages)) {
+        LOG_INF("%s: Using tools to build a grammar\n", __func__);
+
+        data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+            auto schemas = json::array();
+            foreach_function(inputs.tools, [&](const json & tool) {
+                const auto & function = tool.at("function");
+                schemas.push_back({
+                    {"type", "object"},
+                    {"properties", {
+                        {"name", {
+                            {"type", "string"},
+                            {"const", function.at("name")},
+                        }},
+                        {"arguments", function.at("parameters")},
+                    }},
+                    {"required", json::array({"name", "arguments", "id"})},
+                });
+            });
+            auto schema = json {
+                {"type", "array"},
+                {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
+                {"minItems", 1},
+            };
+            if (!inputs.parallel_tool_calls) {
+                schema["maxItems"] = 1;
+            }
+
+            builder.add_rule("root", "\"<|tool_call_start|>\"" + builder.add_schema("tool_calls", schema) + "\"<|tool_call_end|>\"");
+        });
+        // the model has no concept of a tool choice mode: if the system prompt
+        // rendered correctly it will produce a tool call, and the grammar
+        // constrains the tool call body
+        data.grammar_lazy = true;
+        data.grammar_triggers = {{COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, "\\s*<\\|tool_call_start\\|>\\s*\\["}};
+        data.preserved_tokens = {"<|tool_call_start|>", "<|tool_call_end|>"};
+        data.format = COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS;
+    } else if (are_tools_provided && (!is_json_schema_provided && !is_grammar_provided)) {
+        LOG_INF("%s: Using tools without json schema or grammar\n", __func__);
+        // output those tokens
+        data.preserved_tokens = {"<|tool_call_start|>", "<|tool_call_end|>"};
+    } else if (is_json_schema_provided) {
+        LOG_INF("%s: Using provided json schema to build a grammar\n", __func__);
+        data.grammar = json_schema_to_grammar(inputs.json_schema);
+    } else if (is_grammar_provided) {
+        LOG_INF("%s: Using provided grammar\n", __func__);
+        data.grammar = inputs.grammar;
+    } else {
+        LOG_INF("%s: Using content relying on the template\n", __func__);
+    }
+
+    data.prompt = apply(tmpl, inputs, /* messages_override= */ tweaked_messages);
+    LOG_DBG("%s: Prompt: %s\n", __func__, data.prompt.c_str());
+
+    return data;
+}
+
 static common_chat_params common_chat_params_init_magistral(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
     data.prompt = apply(tmpl, inputs);
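
For reference, a minimal standalone sketch (not part of the patch) of the array schema that the "Using tools to build a grammar" branch above assembles before handing it to build_grammar. It assumes nlohmann/json is available and borrows the get_candidate_status tool from the example dialog in the comment; the resulting JSON is what add_schema("tool_calls", schema) would receive for a single tool with parallel_tool_calls disabled.

#include <iostream>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    // Parameters of the get_candidate_status tool from the example dialog above.
    json parameters = {
        {"type", "object"},
        {"properties", {{"candidate_id", {{"type", "string"}, {"description", "Unique identifier for the candidate"}}}}},
        {"required", {"candidate_id"}},
    };
    // Per-call schema: the "name" field is pinned to the tool name via "const".
    // "id" is listed in required, mirroring the patch, even though no "id"
    // property is declared.
    json item = {
        {"type", "object"},
        {"properties", {
            {"name", {{"type", "string"}, {"const", "get_candidate_status"}}},
            {"arguments", parameters},
        }},
        {"required", {"name", "arguments", "id"}},
    };
    // With one tool the item schema is used directly; several tools would be
    // wrapped as {"anyOf": [...]}. maxItems mirrors !parallel_tool_calls.
    json schema = {
        {"type", "array"},
        {"items", item},
        {"minItems", 1},
        {"maxItems", 1},
    };
    std::cout << schema.dump(2) << std::endl;
}
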
@@ -2499,6 +2623,71 @@ static void common_chat_parse_apertus(common_chat_msg_parser & builder) {
     builder.add_content(builder.consume_rest());
 }
 
+
+static void common_chat_parse_lfm2(common_chat_msg_parser & builder) {
+    if (!builder.syntax().parse_tool_calls) {
+        builder.add_content(builder.consume_rest());
+        return;
+    }
+
+    // LFM2 format: <|tool_call_start|>[{"name": "get_current_time", "arguments": {"location": "Paris"}}]<|tool_call_end|>
+    static const common_regex tool_call_start_regex(regex_escape("<|tool_call_start|>"));
+    static const common_regex tool_call_end_regex(regex_escape("<|tool_call_end|>"));
+
+    // Loop through all tool calls
+    while (auto res = builder.try_find_regex(tool_call_start_regex, std::string::npos, /* add_prelude_to_content= */ true)) {
+        builder.move_to(res->groups[0].end);
+
+        // Parse JSON array format: [{"name": "...", "arguments": {...}}]
+        auto tool_calls_data = builder.consume_json();
+
+        // Consume end marker
+        builder.consume_spaces();
+        if (!builder.try_consume_regex(tool_call_end_regex)) {
+            throw common_chat_msg_partial_exception("Expected <|tool_call_end|>");
+        }
+
+        // Process each tool call in the array
+        if (tool_calls_data.json.is_array()) {
+            for (const auto & tool_call : tool_calls_data.json) {
+                if (!tool_call.is_object()) {
+                    throw common_chat_msg_partial_exception("Tool call must be an object");
+                }
+
+                if (!tool_call.contains("name")) {
+                    throw common_chat_msg_partial_exception("Tool call missing 'name' field");
+                }
+
+                std::string function_name = tool_call.at("name");
+                std::string arguments = "{}";
+
+                if (tool_call.contains("arguments")) {
+                    if (tool_call.at("arguments").is_object()) {
+                        arguments = tool_call.at("arguments").dump();
+                    } else if (tool_call.at("arguments").is_string()) {
+                        arguments = tool_call.at("arguments");
+                    }
+                }
+
+                if (!builder.add_tool_call(function_name, "", arguments)) {
+                    throw common_chat_msg_partial_exception("Incomplete tool call");
+                }
+            }
+        } else {
+            throw common_chat_msg_partial_exception("Expected JSON array for tool calls");
+        }
+
+        // Consume any trailing whitespace after this tool call
+        builder.consume_spaces();
+    }
+
+    // Consume any remaining content after all tool calls
+    auto remaining = builder.consume_rest();
+    if (!string_strip(remaining).empty()) {
+        builder.add_content(remaining);
+    }
+}
+
 static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) {
     // Parse thinking tags first - this handles the main reasoning content
     builder.try_parse_reasoning("<seed:think>", "</seed:think>");
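
To make the wire format concrete, here is a small self-contained sketch (not part of the patch) that pulls the same fields out of a complete, non-streamed assistant message. It uses plain std::string search and nlohmann/json in place of the incremental common_chat_msg_parser API above, so it ignores the partial-message handling the real parser provides.

#include <cassert>
#include <string>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    const std::string msg =
        "<|tool_call_start|>[{\"name\": \"get_current_time\", "
        "\"arguments\": {\"location\": \"Paris\"}}]<|tool_call_end|>"
        "Checking the time in Paris.";

    const std::string start = "<|tool_call_start|>";
    const std::string end   = "<|tool_call_end|>";
    const size_t b = msg.find(start) + start.size();
    const size_t e = msg.find(end);

    // The payload between the markers is a JSON array of tool calls.
    const json calls = json::parse(msg.substr(b, e - b));
    assert(calls.is_array() && calls[0].is_object());
    assert(calls[0].at("name") == "get_current_time");

    // As in the parser above, object arguments are re-serialized to a string.
    const std::string arguments = calls[0].at("arguments").dump();
    assert(arguments == "{\"location\":\"Paris\"}");

    // Anything after the end marker is regular assistant content.
    assert(msg.substr(e + end.size()) == "Checking the time in Paris.");
    return 0;
}
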
@@ -2748,6 +2937,12 @@ static common_chat_params common_chat_templates_apply_jinja(
         return common_chat_params_init_apertus(tmpl, params);
     }
 
+    // LFM2 (w/ tools)
+    if (src.find("List of tools: <|tool_list_start|>[") != std::string::npos &&
+        src.find("]<|tool_list_end|>") != std::string::npos) {
+        return common_chat_params_init_lfm2(tmpl, params);
+    }
+
     // Use generic handler when mixing tools + JSON schema.
     // TODO: support that mix in handlers below.
     if ((params.tools.is_array() && params.json_schema.is_object())) {
@@ -2926,6 +3121,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
         case COMMON_CHAT_FORMAT_APERTUS:
             common_chat_parse_apertus(builder);
             break;
+        case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS:
+            common_chat_parse_lfm2(builder);
+            break;
         default:
             throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
     }
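
End to end at the server level: with the LFM2 template active (detected above by sniffing the Jinja source for the tool-list markers), a client opts into the grammar-constrained tool path by placing the marker line in the system prompt. A sketch of a llama-server /v1/chat/completions request body (standard OpenAI-compatible fields; the get_current_time tool is hypothetical):

{
  "messages": [
    {"role": "system", "content": "You are a helpful assistant. Force json schema."},
    {"role": "user", "content": "What time is it in Paris?"}
  ],
  "tools": [{
    "type": "function",
    "function": {
      "name": "get_current_time",
      "description": "Returns the current time for a location",
      "parameters": {
        "type": "object",
        "properties": {"location": {"type": "string"}},
        "required": ["location"]
      }
    }
  }]
}

The marker is stripped from the system message before rendering, the tool schema is compiled into a grammar that triggers on <|tool_call_start|>[, and the model's reply is parsed back into tool calls by common_chat_parse_lfm2.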

common/chat.h

Lines changed: 1 addition & 0 deletions
@@ -116,6 +116,7 @@ enum common_chat_format {
     COMMON_CHAT_FORMAT_SEED_OSS,
     COMMON_CHAT_FORMAT_NEMOTRON_V2,
     COMMON_CHAT_FORMAT_APERTUS,
+    COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS,
 
     COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
 };

(new file: LFM2 Jinja chat template)

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+{{- bos_token -}}
+{%- set system_prompt = "" -%}
+{%- set ns = namespace(system_prompt="") -%}
+{%- if messages[0]["role"] == "system" -%}
+    {%- set ns.system_prompt = messages[0]["content"] -%}
+    {%- set messages = messages[1:] -%}
+{%- endif -%}
+{%- if tools -%}
+    {%- set ns.system_prompt = ns.system_prompt + ("\n" if ns.system_prompt else "") + "List of tools: <|tool_list_start|>[" -%}
+    {%- for tool in tools -%}
+        {%- if tool is not string -%}
+            {%- set tool = tool | tojson -%}
+        {%- endif -%}
+        {%- set ns.system_prompt = ns.system_prompt + tool -%}
+        {%- if not loop.last -%}
+            {%- set ns.system_prompt = ns.system_prompt + ", " -%}
+        {%- endif -%}
+    {%- endfor -%}
+    {%- set ns.system_prompt = ns.system_prompt + "]<|tool_list_end|>" -%}
+{%- endif -%}
+{%- if ns.system_prompt -%}
+    {{- "<|im_start|>system\n" + ns.system_prompt + "<|im_end|>\n" -}}
+{%- endif -%}
+{%- for message in messages -%}
+    {{- "<|im_start|>" + message["role"] + "\n" -}}
+    {%- set content = message["content"] -%}
+    {%- if content is not string -%}
+        {%- set content = content | tojson -%}
+    {%- endif -%}
+    {%- if message["role"] == "tool" -%}
+        {%- set content = "<|tool_response_start|>" + content + "<|tool_response_end|>" -%}
+    {%- endif -%}
+    {{- content + "<|im_end|>\n" -}}
+{%- endfor -%}
+{%- if add_generation_prompt -%}
+    {{- "<|im_start|>assistant\n" -}}
+{%- endif -%}
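
As a quick check of the template logic: with no user-supplied system message, a single tool (the get_candidate_status definition from the chat.cpp example dialog), and add_generation_prompt=true, the render reproduces the head of the PyTorch dialog shown earlier, assuming bos_token renders as <|startoftext|> and the tool is passed as a bare function object:

<|startoftext|><|im_start|>system
List of tools: <|tool_list_start|>[{"name": "get_candidate_status", "description": "Retrieves the current status of a candidate in the recruitment process", "parameters": {"type": "object", "properties": {"candidate_id": {"type": "string", "description": "Unique identifier for the candidate"}}, "required": ["candidate_id"]}}]<|tool_list_end|><|im_end|>
<|im_start|>user
What is the current status of candidate ID 12345?<|im_end|>
<|im_start|>assistant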
