
Commit 137631e

dkalinowski authored and przepeck committed
Mistral tool calling unary (#3567)
CVS-171565
1 parent 05a50cd commit 137631e

File tree

11 files changed: +348 −6 lines


.dockerignore

Lines changed: 1 addition & 0 deletions
@@ -4,3 +4,4 @@
 out
 demos/continuous_batching
 demos/embeddings
+demos/common/export_models/models

demos/common/export_models/export_model.py

Lines changed: 4 additions & 3 deletions
@@ -52,7 +52,7 @@ def add_common_arguments(parser):
                              'Not effective if target device is not NPU', dest='max_prompt_len')
     parser_text.add_argument('--prompt_lookup_decoding', action='store_true', help='Set pipeline to use prompt lookup decoding', dest='prompt_lookup_decoding')
     parser_text.add_argument('--reasoning_parser', choices=["qwen3"], help='Set the type of the reasoning parser for reasoning content extraction', dest='reasoning_parser')
-    parser_text.add_argument('--tool_parser', choices=["llama3","phi4","hermes3", "qwen3"], help='Set the type of the tool parser for tool calls extraction', dest='tool_parser')
+    parser_text.add_argument('--tool_parser', choices=["llama3","phi4","hermes3", "qwen3","mistral"], help='Set the type of the tool parser for tool calls extraction', dest='tool_parser')
     parser_text.add_argument('--enable_tool_guided_generation', action='store_true', help='Enables enforcing tool schema during generation. Requires setting tool_parser', dest='enable_tool_guided_generation')

     parser_embeddings = subparsers.add_parser('embeddings', help='[deprecated] export model for embeddings endpoint with models split into separate, versioned directories')
@@ -464,15 +464,16 @@ def export_text_generation_model(model_repository_path, source_model, model_name
         f.write(graph_content)
     print("Created graph {}".format(os.path.join(model_repository_path, model_name, 'graph.pbtxt')))

-    if template_parameters.get("tools_model_type") is not None:
+    if template_parameters.get("tool_parser") is not None:
         print("Adding tuned chat template")
         template_mapping = {
             "phi4": "tool_chat_template_phi4_mini.jinja",
             "llama3": "tool_chat_template_llama3.1_json.jinja",
             "hermes3": "tool_chat_template_hermes.jinja",
+            "mistral": "tool_chat_template_mistral_parallel.jinja",
             "qwen3": None
         }
-        template_name = template_mapping[task_parameters.get("tools_model_type")]
+        template_name = template_mapping[task_parameters.get("tool_parser")]
         if template_name is not None:
             template_path = os.path.join(model_repository_path, model_name, "template.jinja")
             import requests
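
For context, a minimal Python sketch (not part of the diff) of the corrected template selection: the lookup key now matches the argument's dest='tool_parser', so choosing --tool_parser mistral maps to tool_chat_template_mistral_parallel.jinja. The task_parameters dict below is a hypothetical stand-in for the parameter dictionary the script builds from the parsed CLI arguments.

# Hypothetical stand-in for the parameters dict built from the parsed CLI arguments.
task_parameters = {"tool_parser": "mistral"}

# Mapping copied from the diff above; "qwen3" deliberately maps to None (no tuned template is downloaded).
template_mapping = {
    "phi4": "tool_chat_template_phi4_mini.jinja",
    "llama3": "tool_chat_template_llama3.1_json.jinja",
    "hermes3": "tool_chat_template_hermes.jinja",
    "mistral": "tool_chat_template_mistral_parallel.jinja",
    "qwen3": None,
}

if task_parameters.get("tool_parser") is not None:
    template_name = template_mapping[task_parameters.get("tool_parser")]
    print(template_name)  # tool_chat_template_mistral_parallel.jinja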

prepare_llm_models.sh

Lines changed: 10 additions & 1 deletion
@@ -30,8 +30,9 @@ QWEN3_MODEL="Qwen/Qwen3-8B"
 LLAMA3_MODEL="meta-llama/Llama-3.1-8B-Instruct"
 HERMES3_MODEL="NousResearch/Hermes-3-Llama-3.1-8B"
 PHI4_MODEL="microsoft/Phi-4-mini-instruct"
+MISTRAL_MODEL="mistralai/Mistral-7B-Instruct-v0.3"

-MODELS=("$CB_MODEL" "$EMBEDDING_MODEL" "$RERANK_MODEL" "$VLM_MODEL" "$QWEN3_MODEL" "$LLAMA3_MODEL" "$HERMES3_MODEL" "$PHI4_MODEL" "$EMBEDDING_MODEL/ov" "$RERANK_MODEL/ov")
+MODELS=("$CB_MODEL" "$EMBEDDING_MODEL" "$RERANK_MODEL" "$VLM_MODEL" "$QWEN3_MODEL" "$LLAMA3_MODEL" "$HERMES3_MODEL" "$PHI4_MODEL" "$MISTRAL_MODEL" "$EMBEDDING_MODEL/ov" "$RERANK_MODEL/ov")

 all_exist=true
 for model in "${MODELS[@]}"; do
@@ -126,3 +127,11 @@ else
     mkdir -p $1/$PHI4_MODEL
     convert_tokenizer $PHI4_MODEL --with_detokenizer -o $1/$PHI4_MODEL
 fi
+
+if [ -d "$1/$MISTRAL_MODEL" ]; then
+    echo "Models directory $1/$MISTRAL_MODEL exists. Skipping downloading models."
+else
+    mkdir -p $1/$MISTRAL_MODEL
+    convert_tokenizer $MISTRAL_MODEL --with_detokenizer -o $1/$MISTRAL_MODEL
+fi
+

src/BUILD

Lines changed: 1 addition & 0 deletions
@@ -2511,6 +2511,7 @@ cc_test(
         "test/llm/output_parsers/qwen3_output_parser_test.cpp",
         "test/llm/output_parsers/hermes3_output_parser_test.cpp",
         "test/llm/output_parsers/phi4_output_parser_test.cpp",
+        "test/llm/output_parsers/mistral_output_parser_test.cpp",
         "test/llm/output_parsers/partial_json_builder_test.cpp",
     ],
     "//:disable_python" : [],

src/llm/BUILD

Lines changed: 2 additions & 0 deletions
@@ -114,6 +114,7 @@ cc_library(
         "io_processing/hermes3/tool_parser.hpp",
         "io_processing/llama3/tool_parser.hpp",
         "io_processing/phi4/tool_parser.hpp",
+        "io_processing/mistral/tool_parser.hpp",
         "io_processing/qwen3/reasoning_parser.hpp",
         "io_processing/output_parser.hpp",
         "io_processing/partial_json_builder.hpp",
@@ -122,6 +123,7 @@ cc_library(
         "io_processing/hermes3/tool_parser.cpp",
         "io_processing/llama3/tool_parser.cpp",
         "io_processing/phi4/tool_parser.cpp",
+        "io_processing/mistral/tool_parser.cpp",
         "io_processing/qwen3/reasoning_parser.cpp",
         "io_processing/output_parser.cpp",
         "io_processing/partial_json_builder.cpp",
src/llm/io_processing/mistral/tool_parser.cpp (new file)

Lines changed: 88 additions & 0 deletions

@@ -0,0 +1,88 @@
+//*****************************************************************************
+// Copyright 2025 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#include <openvino/genai/tokenizer.hpp>
+#include <string>
+#include <vector>
+#include <regex>
+
+#pragma warning(push)
+#pragma warning(disable : 6313)
+#include <rapidjson/document.h>
+#include <rapidjson/stringbuffer.h>
+#include <rapidjson/writer.h>
+#pragma warning(pop)
+
+#include "../../../logging.hpp"
+#include "tool_parser.hpp"
+#include "../utils.hpp"
+
+namespace ovms {
+
+void MistralToolParser::parse(ParsedOutput& parsedOutput, const std::vector<int64_t>& generatedTokens) {
+    std::vector<std::string> tools;
+
+    if (parsedOutput.content.empty() || generatedTokens.size() <= 0) {
+        SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "No content to parse for tool calls");
+        return;
+    }
+
+    if (generatedTokens[0] != this->botTokenId) {
+        SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Failed to parse functools content or extract tools array");
+        return;
+    }
+
+    rapidjson::Document toolsDoc;
+    toolsDoc.Parse(parsedOutput.content.c_str());
+
+    if (!toolsDoc.HasParseError() && toolsDoc.IsArray()) {
+        for (auto& toolVal : toolsDoc.GetArray()) {
+            if (!toolVal.IsObject()) {
+                SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Tool call is not a valid JSON object");
+                continue;
+            }
+            ToolCall toolCall;
+            if (toolVal.HasMember("name") && toolVal["name"].IsString()) {
+                toolCall.name = toolVal["name"].GetString();
+            } else {
+                SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Tool call does not contain valid name field");
+                continue;
+            }
+
+            if (toolVal.HasMember("arguments") && toolVal["arguments"].IsObject()) {
+                rapidjson::StringBuffer sb;
+                rapidjson::Writer<rapidjson::StringBuffer> toolWriter(sb);
+                toolVal["arguments"].Accept(toolWriter);
+                toolCall.arguments = sb.GetString();
+            } else {
+                SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Tool call does not contain valid parameters object");
+                continue;
+            }
+            toolCall.id = generateRandomId();  // Generate a random ID for the tool call
+            parsedOutput.toolCalls.push_back(toolCall);
+        }
+        parsedOutput.content.clear();
+    } else {
+        SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Failed to parse functools content or extract tools array");
+    }
+}
+
+std::optional<rapidjson::Document> MistralToolParser::parseChunk(const std::string& chunk) {
+    // Not implemented
+    SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "MistralToolParser::parseChunk is not implemented");
+    return std::nullopt;
+}
+}  // namespace ovms
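
The parser above consumes Mistral's unary tool-call format. As an illustration (not part of the commit), here is a small standalone Python sketch of the same logic: when generation starts with the [TOOL_CALLS] special token (token id 5 for this tokenizer), the decoded content that follows is expected to be a JSON array of {"name": ..., "arguments": {...}} objects, and each "arguments" object is re-serialized into a compact JSON string, mirroring the rapidjson handling in parse(). The payload below is purely illustrative.

import json

# Illustrative decoded content following the [TOOL_CALLS] token.
decoded_content = '[{"name": "get_weather", "arguments": {"city": "Paris"}}]'

tool_calls = []
for entry in json.loads(decoded_content):
    # Skip entries that are not objects or that lack valid fields, as the C++ parser does.
    if not isinstance(entry, dict):
        continue
    name = entry.get("name")
    arguments = entry.get("arguments")
    if not isinstance(name, str) or not isinstance(arguments, dict):
        continue
    tool_calls.append({
        "name": name,
        # Arguments stay a serialized JSON string, matching toolCall.arguments above.
        "arguments": json.dumps(arguments, separators=(",", ":")),
    })

print(tool_calls)  # [{'name': 'get_weather', 'arguments': '{"city":"Paris"}'}]
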
src/llm/io_processing/mistral/tool_parser.hpp (new file)

Lines changed: 53 additions & 0 deletions

@@ -0,0 +1,53 @@
+//*****************************************************************************
+// Copyright 2025 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+#pragma once
+
+#include <openvino/genai/tokenizer.hpp>
+#include <string>
+#include <optional>
+#include <vector>
+
+#pragma warning(push)
+#pragma warning(disable : 6313)
+#include <rapidjson/document.h>
+#include <rapidjson/stringbuffer.h>
+#include <rapidjson/writer.h>
+#pragma warning(pop)
+
+#include "../base_output_parser.hpp"
+
+namespace ovms {
+class MistralToolParser : public BaseOutputParser {
+    const int64_t botTokenId = 5;  // [TOOL_CALLS]
+
+public:
+    MistralToolParser() = delete;
+    explicit MistralToolParser(ov::genai::Tokenizer& tokenizer) :
+        BaseOutputParser(tokenizer) {}
+
+    void parse(ParsedOutput& parsedOutput, const std::vector<int64_t>& generatedTokens) override;
+    std::optional<rapidjson::Document> parseChunk(const std::string& chunk) override;
+    const std::string& getParsingStartTag() const override {
+        static const std::string toolCallStartTag = "[TOOL_CALLS]";
+        return toolCallStartTag;
+    }
+    // Tool calls are expected to be the last part of the content, so we do not specify an end tag.
+    const std::string& getParsingEndTag() const override {
+        static const std::string toolCallEndTag = "";
+        return toolCallEndTag;
+    }
+};
+}  // namespace ovms

src/llm/io_processing/output_parser.cpp

Lines changed: 3 additions & 0 deletions
@@ -19,6 +19,7 @@
 #include "llama3/tool_parser.hpp"
 #include "hermes3/tool_parser.hpp"
 #include "phi4/tool_parser.hpp"
+#include "mistral/tool_parser.hpp"
 #include "qwen3/reasoning_parser.hpp"

 namespace ovms {
@@ -46,6 +47,8 @@ OutputParser::OutputParser(ov::genai::Tokenizer& tokenizer, const std::string to
         toolParser = std::make_unique<Hermes3ToolParser>(tokenizer);
     } else if (toolParserName == "phi4") {
         toolParser = std::make_unique<Phi4ToolParser>(tokenizer);
+    } else if (toolParserName == "mistral") {
+        toolParser = std::make_unique<MistralToolParser>(tokenizer);
     } else if (!toolParserName.empty()) {
         throw std::runtime_error("Unsupported tool parser: " + toolParserName);
     }

src/llm/servable_initializer.cpp

Lines changed: 8 additions & 1 deletion
@@ -60,6 +60,8 @@ void GenAiServableInitializer::loadPyTemplateProcessor(std::shared_ptr<GenAiServ
 global json
 import json
 from pathlib import Path
+global datetime
+import datetime

 global contextmanager
 from contextlib import contextmanager
@@ -73,6 +75,10 @@ void GenAiServableInitializer::loadPyTemplateProcessor(std::shared_ptr<GenAiServ
 def raise_exception(message):
     raise jinja2.exceptions.TemplateError(message)

+# Appears in some of mistral chat templates
+def strftime_now(format):
+    return datetime.datetime.now().strftime(format)
+
 # Following the logic from:
 # https://github.com/huggingface/transformers/blob/7188e2e28c6d663284634732564143b820a03f8b/src/transformers/utils/chat_template_utils.py#L398
 class AssistantTracker(Extension):
@@ -135,7 +141,8 @@ void GenAiServableInitializer::loadPyTemplateProcessor(std::shared_ptr<GenAiServ
 template_loader = jinja2.FileSystemLoader(searchpath=templates_directory)
 jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True, extensions=[AssistantTracker, jinja2.ext.loopcontrols], loader=template_loader)
 jinja_env.policies["json.dumps_kwargs"]["ensure_ascii"] = False
-jinja_env.globals["raise_exception"] = raise_exception
+jinja_env.globals["raise_exception"] = raise_exception
+jinja_env.globals["strftime_now"] = strftime_now
 if jinja_file.is_file():
     template = jinja_env.get_template("template.jinja")
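
As a standalone illustration of why this hook is needed (some Mistral chat templates call strftime_now to embed the current date), the sketch below registers the same helper on a sandboxed Jinja environment. It is not part of the commit; the inline template string is illustrative and it assumes jinja2 is installed.

import datetime
from jinja2.sandbox import ImmutableSandboxedEnvironment

def strftime_now(format):
    # Same helper as registered in loadPyTemplateProcessor above.
    return datetime.datetime.now().strftime(format)

jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
jinja_env.globals["strftime_now"] = strftime_now

# Illustrative template line; the server loads the real chat template from template.jinja.
template = jinja_env.from_string("Today is {{ strftime_now('%d %b %Y') }}.")
print(template.render())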
