Merged
42 changes: 26 additions & 16 deletions demos/common/export_models/export_model.py
@@ -52,7 +52,7 @@ def add_common_arguments(parser):
                             'Not effective if target device is not NPU', dest='max_prompt_len')
     parser_text.add_argument('--prompt_lookup_decoding', action='store_true', help='Set pipeline to use prompt lookup decoding', dest='prompt_lookup_decoding')
     parser_text.add_argument('--reasoning_parser', choices=["qwen3"], help='Set the type of the reasoning parser for reasoning content extraction', dest='reasoning_parser')
-    parser_text.add_argument('--tool_parser', choices=["llama3","phi4","hermes3"], help='Set the type of the tool parser for tool calls extraction', dest='tool_parser')
+    parser_text.add_argument('--tool_parser', choices=["llama3","phi4","hermes3","mistral"], help='Set the type of the tool parser for tool calls extraction', dest='tool_parser')
     parser_text.add_argument('--enable_tool_guided_generation', action='store_true', help='Enables enforcing tool schema during generation. Requires setting tool_parser', dest='enable_tool_guided_generation')

parser_embeddings = subparsers.add_parser('embeddings', help='[deprecated] export model for embeddings endpoint with models split into separate, versioned directories')
@@ -458,23 +458,33 @@ def export_text_generation_model(model_repository_path, source_model, model_name
f.write(graph_content)
print("Created graph {}".format(os.path.join(model_repository_path, model_name, 'graph.pbtxt')))

if template_parameters.get("tools_model_type") is not None:
if template_parameters.get("tool_parser") is not None:
print("Adding tuned chat template")
template_mapping = {
"phi4": "tool_chat_template_phi4_mini.jinja",
"llama3": "tool_chat_template_llama3.1_json.jinja",
"hermes3": "tool_chat_template_hermes.jinja",
"qwen3": None
}
template_name = template_mapping[task_parameters.get("tools_model_type")]
if template_name is not None:
# Custom Templates
if template_parameters.get("tool_parser") == "mistral":
template_path = os.path.join(model_repository_path, model_name, "template.jinja")
import requests
response = requests.get("https://raw.githubusercontent.com/vllm-project/vllm/refs/tags/v0.9.0/examples/" + template_name)
print(response.raise_for_status())
with open(template_path, "wb") as f:
f.write(response.content)
print(f"Downloaded tuned chat template to {template_path}")
with open(template_path, "w") as f:
# Modified from https://raw.githubusercontent.com/vllm-project/vllm/refs/tags/v0.9.0/examples/tool_chat_template_mistral_parallel.jinja
with open(os.path.dirname(os.path.abspath(__file__)) + "/templates/tool_chat_template_mistral_parallel.jinja", "r") as template_file:
content = template_file.read()
f.write(content)
print(f"Added tuned chat template to {template_path}")
else: # VLLM templates
template_mapping = {
"phi4": "tool_chat_template_phi4_mini.jinja",
"llama3": "tool_chat_template_llama3.1_json.jinja",
"hermes3": "tool_chat_template_hermes.jinja",
"qwen3": None
}
template_name = template_mapping[task_parameters.get("tool_parser")]
if template_name is not None:
template_path = os.path.join(model_repository_path, model_name, "template.jinja")
import requests
response = requests.get("https://raw.githubusercontent.com/vllm-project/vllm/refs/tags/v0.9.0/examples/" + template_name)
print(response.raise_for_status())
with open(template_path, "wb") as f:
f.write(response.content)
print(f"Downloaded tuned chat template to {template_path}")

add_servable_to_config(config_file_path, model_name, os.path.relpath( os.path.join(model_repository_path, model_name), os.path.dirname(config_file_path)))

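With these changes, exporting a Mistral model with the new parser would look roughly like this. A minimal sketch, not taken from this diff: only --tool_parser and --enable_tool_guided_generation appear above, so the subcommand name, model id, and path flags here are illustrative assumptions.

```python
# Hypothetical invocation of export_model.py with the new "mistral" tool parser.
# The "text_generation" subcommand name, model id, and path flags are assumptions.
import subprocess

subprocess.run([
    "python", "export_model.py", "text_generation",
    "--source_model", "mistralai/Mistral-7B-Instruct-v0.3",  # assumed model id
    "--model_repository_path", "models",                     # assumed flag/value
    "--tool_parser", "mistral",
    "--enable_tool_guided_generation",
], check=True)
```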
93 changes: 93 additions & 0 deletions demos/common/export_models/templates/tool_chat_template_mistral_parallel.jinja
@@ -0,0 +1,93 @@
{%- if messages[0]["role"] == "system" %}
[Review comment — dtrawins (Collaborator), Aug 12, 2025]: we don't want to maintain this template here. Check this PR #3565
[Reply — Author (Collaborator)]: changed: now using vLLM's mistral_parallel.jinja, which gives worse accuracy

{%- set system_message = messages[0]["content"] %}
{%- set loop_messages = messages[1:] %}
{%- else %}
{%- set loop_messages = messages %}
{%- endif %}
{%- if not tools is defined %}
{%- set tools = none %}
{%- elif tools is not none %}
{%- set parallel_tool_prompt = "You are a helpful assistant that can call tools. If you call one or more tools, format them in a **SINGLE** JSON array of objects, where each object is a tool call, PLEASE NEST ALL OBJECTS IN SINGLE array. Use the format [{\"name\": tool call name, \"arguments\": tool call arguments}, additional tool calls] if you call more than one tool. If you call tools, do not attempt to interpret them or otherwise provide a response until you receive a tool call result that you can interpret for the user." %}
{%- if system_message is defined %}
{%- set system_message = parallel_tool_prompt + "\n\n" + system_message %}
{%- else %}
{%- set system_message = parallel_tool_prompt %}
{%- endif %}
{%- endif %}
{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}

{%- for message in loop_messages | rejectattr("role", "equalto", "tool") | rejectattr("role", "equalto", "tool_results") | selectattr("tool_calls", "undefined") %}
{%- if (message["role"] == "user") != (loop.index0 % 2 == 0) %}
{{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }}
{%- endif %}
{%- endfor %}

{{- bos_token }}
{%- for message in loop_messages %}
{%- if message["role"] == "user" %}
{%- if tools is not none and (message == user_messages[-1]) %}
{{- "[AVAILABLE_TOOLS] [" }}
{%- for tool in tools %}
{%- set tool = tool.function %}
{{- '{"type": "function", "function": {' }}
{%- for key, val in tool.items() if key != "return" %}
{%- if val is string %}
{{- '"' + key + '": "' + val + '"' }}
{%- else %}
{{- '"' + key + '": ' + val|tojson }}
{%- endif %}
{%- if not loop.last %}
{{- ", " }}
{%- endif %}
{%- endfor %}
{{- "}}" }}
{%- if not loop.last %}
{{- ", " }}
{%- else %}
{{- "]" }}
{%- endif %}
{%- endfor %}
{{- "[/AVAILABLE_TOOLS]" }}
{%- endif %}
{%- if loop.last and system_message is defined %}
{{- "[INST] " + system_message + "\n\n" + message["content"] + "[/INST]" }}
{%- else %}
{{- "[INST] " + message["content"] + "[/INST]" }}
{%- endif %}
{%- elif message["role"] == "tool_calls" or message.tool_calls is defined %}
{%- if message.tool_calls is defined %}
{%- set tool_calls = message.tool_calls %}
{%- else %}
{%- set tool_calls = message.content %}
{%- endif %}
{{- "[TOOL_CALLS] [" }}
{%- for tool_call in tool_calls %}
{%- set out = tool_call.function|tojson %}
{{- out[:-1] }}
{%- if not tool_call.id is defined or tool_call.id|length < 9 %}
{{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (1)" + tool_call.id) }}
{%- endif %}
{{- ', "id": "' + tool_call.id[-9:] + '"}' }}
{%- if not loop.last %}
{{- ", " }}
{%- else %}
{{- "]" + eos_token }}
{%- endif %}
{%- endfor %}
{%- elif message["role"] == "assistant" %}
{{- " " + message["content"] + eos_token }}
{%- elif message["role"] == "tool_results" or message["role"] == "tool" %}
{%- if message.content is defined and message.content.content is defined %}
{%- set content = message.content.content %}
{%- else %}
{%- set content = message.content %}
{%- endif %}
{{- '[TOOL_RESULTS] {"content": ' + content|string + ", " }}
{%- if not message.tool_call_id is defined or message.tool_call_id|length < 9 %}
{{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (2)" + message.tool_call_id) }}
{%- endif %}
{{- '"call_id": "' + message.tool_call_id[-9:] + '"}[/TOOL_RESULTS]' }}
{%- else %}
{{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}
{%- endif %}
{%- endfor %}
1 change: 1 addition & 0 deletions src/BUILD
@@ -2511,6 +2511,7 @@ cc_test(
"test/llm/output_parsers/qwen3_output_parser_test.cpp",
"test/llm/output_parsers/hermes3_output_parser_test.cpp",
"test/llm/output_parsers/phi4_output_parser_test.cpp",
"test/llm/output_parsers/mistral_output_parser_test.cpp",
"test/llm/output_parsers/partial_json_builder_test.cpp",
],
"//:disable_python" : [],
2 changes: 2 additions & 0 deletions src/llm/BUILD
@@ -114,6 +114,7 @@ cc_library(
"io_processing/hermes3/tool_parser.hpp",
"io_processing/llama3/tool_parser.hpp",
"io_processing/phi4/tool_parser.hpp",
"io_processing/mistral/tool_parser.hpp",
"io_processing/qwen3/reasoning_parser.hpp",
"io_processing/output_parser.hpp",
"io_processing/partial_json_builder.hpp",
@@ -122,6 +123,7 @@
"io_processing/hermes3/tool_parser.cpp",
"io_processing/llama3/tool_parser.cpp",
"io_processing/phi4/tool_parser.cpp",
"io_processing/mistral/tool_parser.cpp",
"io_processing/qwen3/reasoning_parser.cpp",
"io_processing/output_parser.cpp",
"io_processing/partial_json_builder.cpp",
114 changes: 114 additions & 0 deletions src/llm/io_processing/mistral/tool_parser.cpp
@@ -0,0 +1,114 @@
//*****************************************************************************
// Copyright 2025 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include <openvino/genai/tokenizer.hpp>
#include <string>
#include <vector>
#include <regex>

#pragma warning(push)
#pragma warning(disable : 6313)
#include <rapidjson/document.h>
#include <rapidjson/stringbuffer.h>
#include <rapidjson/writer.h>
#pragma warning(pop)

#include "../../../logging.hpp"
#include "tool_parser.hpp"
#include "../utils.hpp"

namespace ovms {

static std::string::iterator skipToFirstNonWhitespaceCharacter(std::string::iterator it, const std::string::iterator& end) {
    while (it != end && std::isspace(*it)) {
        ++it;
    }
    return it;
}

void MistralToolParser::parse(ParsedOutput& parsedOutput, const std::vector<int64_t>& generatedTokens) {
    std::vector<std::string> tools;

    if (parsedOutput.content.empty()) {
        SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "No content to parse for tool calls");
        return;
    }

    // Mistral with the vLLM template produces tool calls in the format:
    // [{"name": [function name], "arguments": [function arguments as JSON]}, ...]
    // so we check that the first non-whitespace characters are '[' and '{'.
    bool isToolGenerated = false;
    auto begin = parsedOutput.content.begin();
    begin = skipToFirstNonWhitespaceCharacter(begin, parsedOutput.content.end());
    if (begin != parsedOutput.content.end() && *begin == '[') {
        begin = skipToFirstNonWhitespaceCharacter(begin + 1, parsedOutput.content.end());
        if (begin != parsedOutput.content.end() && *begin == '{') {
            // If the content starts with '[{', it indicates that tool calls are present.
[Review comment — Collaborator]: Did you make sure Mistral does not generate any BOT token that we might not see after decoding?
[Reply — Author (Collaborator)]: BOT tokens appear when we detokenize with skip_special_tokens=false. Changed the parser to base on that.
            isToolGenerated = true;
        } else {
            SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Tool calls are not generated by the model");
        }
    } else {
        SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Tool calls are not generated by the model");
    }
[Review comment — Collaborator]: Since we have isToolGenerated, maybe we could drop these elses and go with

    if (!isToolGenerated) {
        SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Tool calls are not generated by the model");
    }

outside of this block?
[Reply — Author (Collaborator)]: not needed anymore

    if (isToolGenerated) {
        // Extract the content before the tools part
        // parsedOutput.content = decoded.substr(0, toolsStartPos);
        // Extract the tools part, assuming it's all the remaining content after "[{"
        rapidjson::Document toolsDoc;
        toolsDoc.Parse(parsedOutput.content.c_str());
        if (!toolsDoc.HasParseError() && toolsDoc.IsArray()) {
            for (auto& toolVal : toolsDoc.GetArray()) {
                if (!toolVal.IsObject()) {
                    SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Tool call is not a valid JSON object");
                    continue;
                }
                ToolCall toolCall;
                toolCall.id = generateRandomId();  // Generate a random ID for the tool call
[Review comment — Collaborator]: we can have it at the end before pushing to toolCalls
[Reply — Author (Collaborator)]: done

if (toolVal.HasMember("name") && toolVal["name"].IsString()) {
toolCall.name = toolVal["name"].GetString();
} else {
SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Tool call does not contain valid name field");
continue;
}

if (toolVal.HasMember("arguments") && toolVal["arguments"].IsObject()) {
rapidjson::StringBuffer sb;
rapidjson::Writer<rapidjson::StringBuffer> toolWriter(sb);
toolVal["arguments"].Accept(toolWriter);
toolCall.arguments = sb.GetString();
} else {
SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Tool call does not contain valid parameters object");
continue;
}
parsedOutput.toolCalls.push_back(toolCall);
}
// Remove the tools part from the content
parsedOutput.content.clear();
} else {
SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Failed to parse functools content or extract tools array");
}
}
}

std::optional<rapidjson::Document> MistralToolParser::parseChunk(const std::string& chunk) {
    // Not implemented
    SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "MistralToolParser::parseChunk is not implemented");
    return std::nullopt;
}
} // namespace ovms
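For readers skimming the C++, the detection and extraction contract of parse() can be summarized with a rough Python equivalent. This is an illustrative sketch, not part of the PR; uuid.uuid4().hex[:9] is a stand-in for the generateRandomId() helper.

```python
# Illustrative Python mirror of MistralToolParser::parse (not part of the PR).
# Tool calls are recognized only when the first non-whitespace characters
# are '[' followed by '{'; otherwise the content is returned untouched.
import json
import uuid

def parse_mistral_tool_calls(content: str):
    stripped = content.lstrip()
    if not (stripped.startswith("[") and stripped[1:].lstrip().startswith("{")):
        return content, []  # no tool calls generated by the model
    try:
        calls = json.loads(stripped)
    except json.JSONDecodeError:
        return content, []  # parse failure leaves the content as-is
    if not isinstance(calls, list):
        return content, []
    tool_calls = []
    for call in calls:
        if not isinstance(call, dict):
            continue  # skipped, like the non-object debug branch above
        name, args = call.get("name"), call.get("arguments")
        if not isinstance(name, str) or not isinstance(args, dict):
            continue  # name must be a string, arguments a JSON object
        tool_calls.append({"id": uuid.uuid4().hex[:9],  # stand-in for generateRandomId()
                           "name": name,
                           "arguments": json.dumps(args)})
    return "", tool_calls  # content is cleared once the tool-call array parses
```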
51 changes: 51 additions & 0 deletions src/llm/io_processing/mistral/tool_parser.hpp
@@ -0,0 +1,51 @@
//*****************************************************************************
// Copyright 2025 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once

#include <openvino/genai/tokenizer.hpp>
#include <string>
#include <optional>
#include <vector>

#pragma warning(push)
#pragma warning(disable : 6313)
#include <rapidjson/document.h>
#include <rapidjson/stringbuffer.h>
#include <rapidjson/writer.h>
#pragma warning(pop)

#include "../base_output_parser.hpp"

namespace ovms {
class MistralToolParser : public BaseOutputParser {
public:
    MistralToolParser() = delete;
    explicit MistralToolParser(ov::genai::Tokenizer& tokenizer) :
        BaseOutputParser(tokenizer) {}

    void parse(ParsedOutput& parsedOutput, const std::vector<int64_t>& generatedTokens) override;
    std::optional<rapidjson::Document> parseChunk(const std::string& chunk) override;
    const std::string& getParsingStartTag() const override {
        static const std::string parsingStartTag = "[{";
        return parsingStartTag;
    }
    // Tool calls are expected to be the last part of the content, so we do not specify an end tag.
    const std::string& getParsingEndTag() const override {
        static const std::string parsingEndTag = "[{";
        return parsingEndTag;
    }
};
} // namespace ovms
3 changes: 3 additions & 0 deletions src/llm/io_processing/output_parser.cpp
@@ -19,6 +19,7 @@
#include "llama3/tool_parser.hpp"
#include "hermes3/tool_parser.hpp"
#include "phi4/tool_parser.hpp"
#include "mistral/tool_parser.hpp"
#include "qwen3/reasoning_parser.hpp"

namespace ovms {
@@ -46,6 +47,8 @@ OutputParser::OutputParser(ov::genai::Tokenizer& tokenizer, const std::string to
        toolParser = std::make_unique<Hermes3ToolParser>(tokenizer);
    } else if (toolParserName == "phi4") {
        toolParser = std::make_unique<Phi4ToolParser>(tokenizer);
    } else if (toolParserName == "mistral") {
        toolParser = std::make_unique<MistralToolParser>(tokenizer);
    } else if (!toolParserName.empty()) {
        throw std::runtime_error("Unsupported tool parser: " + toolParserName);
    }
11 changes: 9 additions & 2 deletions src/llm/servable_initializer.cpp
@@ -60,14 +60,20 @@ void GenAiServableInitializer::loadPyTemplateProcessor(std::shared_ptr<GenAiServ
global json
import json
from pathlib import Path
global datetime
import datetime

global jinja2
import jinja2
from jinja2.sandbox import ImmutableSandboxedEnvironment

def raise_exception(message):
raise jinja2.exceptions.TemplateError(message)


# Appears in some Mistral chat templates
def strftime_now(format):
return datetime.datetime.now().strftime(format)

# Default chat template accepts only single message and outputs only its 'content'
# effectively turning it into a regular prompt.
default_chat_template = "{% if messages|length != 1 %} {{ raise_exception('This servable accepts only single message requests') }}{% endif %}{{ messages[0]['content'] }}"
@@ -85,7 +91,8 @@ void GenAiServableInitializer::loadPyTemplateProcessor(std::shared_ptr<GenAiServ
template_loader = jinja2.FileSystemLoader(searchpath=templates_directory)
jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True, loader=template_loader)
jinja_env.policies["json.dumps_kwargs"]["ensure_ascii"] = False
jinja_env.globals["raise_exception"] = raise_exception
jinja_env.globals["raise_exception"] = raise_exception
jinja_env.globals["strftime_now"] = strftime_now
if jinja_file.is_file():
template = jinja_env.get_template("template.jinja")

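A small sketch of why the new global matters: without strftime_now registered, a Mistral template that calls it would fail to render under the sandboxed environment. Illustrative only; the template fragment and date format are assumptions.

```python
# Minimal sketch of the new strftime_now global (illustrative example).
import datetime
import jinja2
from jinja2.sandbox import ImmutableSandboxedEnvironment

def strftime_now(format):
    # Same helper the initializer registers; appears in some Mistral chat templates.
    return datetime.datetime.now().strftime(format)

env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
env.globals["strftime_now"] = strftime_now

# e.g. a template fragment like a Mistral system preamble might do:
t = env.from_string("Today's date is {{ strftime_now('%Y-%m-%d') }}.")
print(t.render())  # -> "Today's date is 2025-..."
```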