-
Notifications
You must be signed in to change notification settings - Fork 218
Mistral tool calling unary #3567
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
b799c93
c033f88
17b361f
5eed9ca
9660fc5
eb9d1e6
a2249e4
fa66a41
ad2cb60
cf47ff2
5f02112
cca406b
94597a1
379bc68
5b8e24a
360c5f3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
{%- if messages[0]["role"] == "system" %} | ||
{%- set system_message = messages[0]["content"] %} | ||
{%- set loop_messages = messages[1:] %} | ||
{%- else %} | ||
{%- set loop_messages = messages %} | ||
{%- endif %} | ||
{%- if not tools is defined %} | ||
{%- set tools = none %} | ||
{%- elif tools is not none %} | ||
{%- set parallel_tool_prompt = "You are a helpful assistant that can call tools. If you call one or more tools, format them in a **SINGLE** JSON array of objects, where each object is a tool call, PLEASE NEST ALL OBJECTS IN SINGLE array. Use the format [{\"name\": tool call name, \"arguments\": tool call arguments}, additional tool calls] if you call more than one tool. If you call tools, do not attempt to interpret them or otherwise provide a response until you receive a tool call result that you can interpret for the user." %} | ||
{%- if system_message is defined %} | ||
{%- set system_message = parallel_tool_prompt + "\n\n" + system_message %} | ||
{%- else %} | ||
{%- set system_message = parallel_tool_prompt %} | ||
{%- endif %} | ||
{%- endif %} | ||
{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %} | ||
|
||
{%- for message in loop_messages | rejectattr("role", "equalto", "tool") | rejectattr("role", "equalto", "tool_results") | selectattr("tool_calls", "undefined") %} | ||
{%- if (message["role"] == "user") != (loop.index0 % 2 == 0) %} | ||
{{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }} | ||
{%- endif %} | ||
{%- endfor %} | ||
|
||
{{- bos_token }} | ||
{%- for message in loop_messages %} | ||
{%- if message["role"] == "user" %} | ||
{%- if tools is not none and (message == user_messages[-1]) %} | ||
{{- "[AVAILABLE_TOOLS] [" }} | ||
{%- for tool in tools %} | ||
{%- set tool = tool.function %} | ||
{{- '{"type": "function", "function": {' }} | ||
{%- for key, val in tool.items() if key != "return" %} | ||
{%- if val is string %} | ||
{{- '"' + key + '": "' + val + '"' }} | ||
{%- else %} | ||
{{- '"' + key + '": ' + val|tojson }} | ||
{%- endif %} | ||
{%- if not loop.last %} | ||
{{- ", " }} | ||
{%- endif %} | ||
{%- endfor %} | ||
{{- "}}" }} | ||
{%- if not loop.last %} | ||
{{- ", " }} | ||
{%- else %} | ||
{{- "]" }} | ||
{%- endif %} | ||
{%- endfor %} | ||
{{- "[/AVAILABLE_TOOLS]" }} | ||
{%- endif %} | ||
{%- if loop.last and system_message is defined %} | ||
{{- "[INST] " + system_message + "\n\n" + message["content"] + "[/INST]" }} | ||
{%- else %} | ||
{{- "[INST] " + message["content"] + "[/INST]" }} | ||
{%- endif %} | ||
{%- elif message["role"] == "tool_calls" or message.tool_calls is defined %} | ||
{%- if message.tool_calls is defined %} | ||
{%- set tool_calls = message.tool_calls %} | ||
{%- else %} | ||
{%- set tool_calls = message.content %} | ||
{%- endif %} | ||
{{- "[TOOL_CALLS] [" }} | ||
{%- for tool_call in tool_calls %} | ||
{%- set out = tool_call.function|tojson %} | ||
{{- out[:-1] }} | ||
{%- if not tool_call.id is defined or tool_call.id|length < 9 %} | ||
{{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (1)" + tool_call.id) }} | ||
{%- endif %} | ||
{{- ', "id": "' + tool_call.id[-9:] + '"}' }} | ||
{%- if not loop.last %} | ||
{{- ", " }} | ||
{%- else %} | ||
{{- "]" + eos_token }} | ||
{%- endif %} | ||
{%- endfor %} | ||
{%- elif message["role"] == "assistant" %} | ||
{{- " " + message["content"] + eos_token }} | ||
{%- elif message["role"] == "tool_results" or message["role"] == "tool" %} | ||
{%- if message.content is defined and message.content.content is defined %} | ||
{%- set content = message.content.content %} | ||
{%- else %} | ||
{%- set content = message.content %} | ||
{%- endif %} | ||
{{- '[TOOL_RESULTS] {"content": ' + content|string + ", " }} | ||
{%- if not message.tool_call_id is defined or message.tool_call_id|length < 9 %} | ||
{{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (2)" + message.tool_call_id) }} | ||
{%- endif %} | ||
{{- '"call_id": "' + message.tool_call_id[-9:] + '"}[/TOOL_RESULTS]' }} | ||
{%- else %} | ||
{{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }} | ||
{%- endif %} | ||
{%- endfor %} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
//***************************************************************************** | ||
// Copyright 2025 Intel Corporation | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
//***************************************************************************** | ||
|
||
#include <openvino/genai/tokenizer.hpp> | ||
#include <string> | ||
#include <vector> | ||
#include <regex> | ||
|
||
#pragma warning(push) | ||
#pragma warning(disable : 6313) | ||
#include <rapidjson/document.h> | ||
#include <rapidjson/stringbuffer.h> | ||
#include <rapidjson/writer.h> | ||
#pragma warning(pop) | ||
|
||
#include "../../../logging.hpp" | ||
#include "tool_parser.hpp" | ||
#include "../utils.hpp" | ||
|
||
namespace ovms { | ||
|
||
static std::string::iterator skipToFirstNonWhitespaceCharacter(std::string::iterator it, const std::string::iterator& end) { | ||
dkalinowski marked this conversation as resolved.
Show resolved
Hide resolved
|
||
while (it != end && std::isspace(*it)) { | ||
++it; | ||
} | ||
return it; | ||
} | ||
|
||
void MistralToolParser::parse(ParsedOutput& parsedOutput, const std::vector<int64_t>& generatedTokens) { | ||
std::vector<std::string> tools; | ||
|
||
if (parsedOutput.content.empty()) { | ||
SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "No content to parse for tool calls"); | ||
return; | ||
} | ||
|
||
// Mistral with vLLM template produces tool calls in the format: | ||
// [{"name": [function name], "arguments": [function arguments as JSON]}, ...] | ||
// So we ensure if first non white characters are [ and { | ||
bool isToolGenerated = false; | ||
dkalinowski marked this conversation as resolved.
Show resolved
Hide resolved
|
||
auto begin = parsedOutput.content.begin(); | ||
begin = skipToFirstNonWhitespaceCharacter(begin, parsedOutput.content.end()); | ||
if (begin != parsedOutput.content.end() && *begin == '[') { | ||
begin = skipToFirstNonWhitespaceCharacter(begin + 1, parsedOutput.content.end()); | ||
if (begin != parsedOutput.content.end() && *begin == '{') { | ||
// If the content starts with '[{', it indicates that tool calls are present. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Did you make sure Mistral does not generate any BOT token that we might not see after decoding? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. BOT tokens appear when we detokenize with skip_special_tokens=false. changed parser to base on that |
||
isToolGenerated = true; | ||
} else { | ||
SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Tool calls are not generated by the model"); | ||
} | ||
} else { | ||
SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Tool calls are not generated by the model"); | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since we have isToolGenerated, maybe we could drop these
outside of this block? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not needed anymore |
||
|
||
|
||
if (isToolGenerated) { | ||
// Extract the content before the tools part | ||
// parsedOutput.content = decoded.substr(0, toolsStartPos); | ||
dkalinowski marked this conversation as resolved.
Show resolved
Hide resolved
|
||
// Extract the tools part, assuming it's all the remaining content after "[{" | ||
rapidjson::Document toolsDoc; | ||
toolsDoc.Parse(parsedOutput.content.c_str()); | ||
if (!toolsDoc.HasParseError() && toolsDoc.IsArray()) { | ||
for (auto& toolVal : toolsDoc.GetArray()) { | ||
if (!toolVal.IsObject()) { | ||
SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Tool call is not a valid JSON object"); | ||
continue; | ||
} | ||
ToolCall toolCall; | ||
toolCall.id = generateRandomId(); // Generate a random ID for the tool call | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we can have it at the end before pushing to toolCalls There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
if (toolVal.HasMember("name") && toolVal["name"].IsString()) { | ||
toolCall.name = toolVal["name"].GetString(); | ||
} else { | ||
SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Tool call does not contain valid name field"); | ||
continue; | ||
} | ||
|
||
if (toolVal.HasMember("arguments") && toolVal["arguments"].IsObject()) { | ||
rapidjson::StringBuffer sb; | ||
rapidjson::Writer<rapidjson::StringBuffer> toolWriter(sb); | ||
toolVal["arguments"].Accept(toolWriter); | ||
toolCall.arguments = sb.GetString(); | ||
} else { | ||
SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Tool call does not contain valid parameters object"); | ||
continue; | ||
} | ||
parsedOutput.toolCalls.push_back(toolCall); | ||
} | ||
// Remove the tools part from the content | ||
parsedOutput.content.clear(); | ||
} else { | ||
SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Failed to parse functools content or extract tools array"); | ||
} | ||
} | ||
} | ||
|
||
std::optional<rapidjson::Document> MistralToolParser::parseChunk(const std::string& chunk) { | ||
// Not implemented | ||
SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "MistralToolParser::parseChunk is not implemented"); | ||
return std::nullopt; | ||
} | ||
} // namespace ovms |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
//***************************************************************************** | ||
// Copyright 2025 Intel Corporation | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
//***************************************************************************** | ||
#pragma once | ||
|
||
#include <openvino/genai/tokenizer.hpp> | ||
#include <string> | ||
#include <optional> | ||
#include <vector> | ||
|
||
#pragma warning(push) | ||
#pragma warning(disable : 6313) | ||
#include <rapidjson/document.h> | ||
#include <rapidjson/stringbuffer.h> | ||
#include <rapidjson/writer.h> | ||
#pragma warning(pop) | ||
|
||
#include "../base_output_parser.hpp" | ||
|
||
namespace ovms { | ||
class MistralToolParser : public BaseOutputParser { | ||
public: | ||
MistralToolParser() = delete; | ||
explicit MistralToolParser(ov::genai::Tokenizer& tokenizer) : | ||
BaseOutputParser(tokenizer) {} | ||
|
||
void parse(ParsedOutput& parsedOutput, const std::vector<int64_t>& generatedTokens) override; | ||
std::optional<rapidjson::Document> parseChunk(const std::string& chunk) override; | ||
const std::string& getParsingStartTag() const override { | ||
static const std::string parsingStartTag = "[{"; | ||
return parsingStartTag; | ||
} | ||
// Tools calls are expected to be the last part of the content, so we do not specify an end tag. | ||
const std::string& getParsingEndTag() const override { | ||
static const std::string parsingEndTag = "[{"; | ||
return parsingEndTag; | ||
} | ||
dkalinowski marked this conversation as resolved.
Show resolved
Hide resolved
|
||
}; | ||
} // namespace ovms |
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we don't want to maintain this template here. Check this PR #3565
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
changed: now using vllm's mistral_parallel.jinja which gives worse accuracy