Skip to content

Commit 51d1140

Browse files
committed
Milosz
1 parent 379bc68 commit 51d1140

File tree

3 files changed

+40
-62
lines changed

3 files changed

+40
-62
lines changed

src/llm/io_processing/mistral/tool_parser.cpp

Lines changed: 33 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -35,63 +35,48 @@ namespace ovms {
3535
void MistralToolParser::parse(ParsedOutput& parsedOutput, const std::vector<int64_t>& generatedTokens) {
3636
std::vector<std::string> tools;
3737

38-
if (parsedOutput.content.empty()) {
38+
if (parsedOutput.content.empty() || generatedTokens.size() <= 0) {
3939
SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "No content to parse for tool calls");
4040
return;
4141
}
4242

43-
std::string decoded = tokenizer.decode(generatedTokens, {ov::genai::skip_special_tokens(false)});
44-
45-
const std::string toolsStartString = getParsingStartTag();
46-
const std::string toolsStartEnd = getParsingEndTag();
47-
48-
size_t toolsStartPos = decoded.find(toolsStartString);
49-
size_t toolsEndPos = decoded.find(toolsStartEnd);
50-
51-
if (toolsStartPos != std::string::npos && toolsEndPos != std::string::npos) {
52-
std::string remaining = decoded.substr(0, toolsStartPos) + decoded.substr(toolsEndPos + toolsStartEnd.length());
53-
54-
size_t toolsStartPos2 = remaining.find(toolsStartString);
55-
size_t toolsEndPos2 = remaining.find(toolsStartEnd);
56-
bool hasMoreSpecialTags = !(toolsStartPos2 == std::string::npos && toolsEndPos2 == std::string::npos);
57-
58-
std::string toolsString = decoded.substr(
59-
toolsStartPos + toolsStartString.length(),
60-
toolsEndPos - toolsStartPos - toolsStartString.length());
43+
if (generatedTokens[0] != this->botTokenId) {
44+
SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Failed to parse functools content or extract tools array");
45+
return;
46+
}
6147

62-
rapidjson::Document toolsDoc;
63-
toolsDoc.Parse(toolsString.c_str());
48+
rapidjson::Document toolsDoc;
49+
toolsDoc.Parse(parsedOutput.content.c_str());
6450

65-
if (!toolsDoc.HasParseError() && toolsDoc.IsArray() && !hasMoreSpecialTags) {
66-
for (auto& toolVal : toolsDoc.GetArray()) {
67-
if (!toolVal.IsObject()) {
68-
SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Tool call is not a valid JSON object");
69-
continue;
70-
}
71-
ToolCall toolCall;
72-
if (toolVal.HasMember("name") && toolVal["name"].IsString()) {
73-
toolCall.name = toolVal["name"].GetString();
74-
} else {
75-
SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Tool call does not contain valid name field");
76-
continue;
77-
}
51+
if (!toolsDoc.HasParseError() && toolsDoc.IsArray()) {
52+
for (auto& toolVal : toolsDoc.GetArray()) {
53+
if (!toolVal.IsObject()) {
54+
SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Tool call is not a valid JSON object");
55+
continue;
56+
}
57+
ToolCall toolCall;
58+
if (toolVal.HasMember("name") && toolVal["name"].IsString()) {
59+
toolCall.name = toolVal["name"].GetString();
60+
} else {
61+
SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Tool call does not contain valid name field");
62+
continue;
63+
}
7864

79-
if (toolVal.HasMember("arguments") && toolVal["arguments"].IsObject()) {
80-
rapidjson::StringBuffer sb;
81-
rapidjson::Writer<rapidjson::StringBuffer> toolWriter(sb);
82-
toolVal["arguments"].Accept(toolWriter);
83-
toolCall.arguments = sb.GetString();
84-
} else {
85-
SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Tool call does not contain valid parameters object");
86-
continue;
87-
}
88-
toolCall.id = generateRandomId(); // Generate a random ID for the tool call
89-
parsedOutput.toolCalls.push_back(toolCall);
65+
if (toolVal.HasMember("arguments") && toolVal["arguments"].IsObject()) {
66+
rapidjson::StringBuffer sb;
67+
rapidjson::Writer<rapidjson::StringBuffer> toolWriter(sb);
68+
toolVal["arguments"].Accept(toolWriter);
69+
toolCall.arguments = sb.GetString();
70+
} else {
71+
SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Tool call does not contain valid parameters object");
72+
continue;
9073
}
91-
parsedOutput.content = remaining;
92-
} else {
93-
SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Failed to parse functools content or extract tools array");
74+
toolCall.id = generateRandomId(); // Generate a random ID for the tool call
75+
parsedOutput.toolCalls.push_back(toolCall);
9476
}
77+
parsedOutput.content.clear();
78+
} else {
79+
SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Failed to parse functools content or extract tools array");
9580
}
9681
}
9782

src/llm/io_processing/mistral/tool_parser.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,7 @@
3131

3232
namespace ovms {
3333
class MistralToolParser : public BaseOutputParser {
34-
const std::string toolCallStartTag = "[TOOL_CALLS]";
35-
const std::string toolCallEndTag = "</s>";
34+
const int64_t botTokenId = 5; // [TOOL_CALLS]
3635

3736
public:
3837
MistralToolParser() = delete;
@@ -42,10 +41,12 @@ class MistralToolParser : public BaseOutputParser {
4241
void parse(ParsedOutput& parsedOutput, const std::vector<int64_t>& generatedTokens) override;
4342
std::optional<rapidjson::Document> parseChunk(const std::string& chunk) override;
4443
const std::string& getParsingStartTag() const override {
44+
static const std::string toolCallStartTag = "[TOOL_CALLS]";
4545
return toolCallStartTag;
4646
}
4747
// Tools calls are expected to be the last part of the content, so we do not specify an end tag.
4848
const std::string& getParsingEndTag() const override {
49+
static const std::string toolCallEndTag = "";
4950
return toolCallEndTag;
5051
}
5152
};

src/test/llm/output_parsers/mistral_output_parser_test.cpp

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -125,28 +125,20 @@ TEST_F(MistralOutputParserTest, ParseToolCallOutputWithContentAndSingleToolCall)
125125
auto generatedTensor = mistralTokenizer.encode(input, ov::genai::add_special_tokens(false)).input_ids;
126126
std::vector<int64_t> generatedTokens(generatedTensor.data<int64_t>(), generatedTensor.data<int64_t>() + generatedTensor.get_size());
127127
ParsedOutput parsedOutput = outputParser->parse(generatedTokens, true);
128-
EXPECT_EQ(parsedOutput.content, "This is a content part and next will be a tool call.\n\n");
128+
EXPECT_EQ(parsedOutput.content, "This is a content part and next will be a tool call.\n\n [{\"name\": \"example_tool\", \"arguments\": {\"arg1\": \"value1\", \"arg2\": 42}}]");
129129
EXPECT_EQ(parsedOutput.reasoning, "");
130130

131-
ASSERT_EQ(parsedOutput.toolCalls.size(), 1);
132-
EXPECT_EQ(parsedOutput.toolCalls[0].name, "example_tool");
133-
// Parser removes whitespaces, so we expect arguments value to be without spaces
134-
EXPECT_EQ(parsedOutput.toolCalls[0].arguments, "{\"arg1\":\"value1\",\"arg2\":42}");
135-
EXPECT_EQ(parsedOutput.toolCalls[0].id.empty(), false); // ID should be generated
131+
ASSERT_EQ(parsedOutput.toolCalls.size(), 0);
136132
}
137133
TEST_F(MistralOutputParserTest, ParseToolCallOutputWithContentOnBothSidesAndSingleToolCall) {
138134
std::string input = "This is a content part and next will be a tool call.\n\n[TOOL_CALLS][{\"name\": \"example_tool\", \"arguments\": {\"arg1\": \"value1\", \"arg2\": 42}}]</s> This is a content part after tool call.";
139135
auto generatedTensor = mistralTokenizer.encode(input, ov::genai::add_special_tokens(false)).input_ids;
140136
std::vector<int64_t> generatedTokens(generatedTensor.data<int64_t>(), generatedTensor.data<int64_t>() + generatedTensor.get_size());
141137
ParsedOutput parsedOutput = outputParser->parse(generatedTokens, true);
142-
EXPECT_EQ(parsedOutput.content, "This is a content part and next will be a tool call.\n\n This is a content part after tool call.");
138+
EXPECT_EQ(parsedOutput.content, "This is a content part and next will be a tool call.\n\n [{\"name\": \"example_tool\", \"arguments\": {\"arg1\": \"value1\", \"arg2\": 42}}] This is a content part after tool call.");
143139
EXPECT_EQ(parsedOutput.reasoning, "");
144140

145-
ASSERT_EQ(parsedOutput.toolCalls.size(), 1);
146-
EXPECT_EQ(parsedOutput.toolCalls[0].name, "example_tool");
147-
// Parser removes whitespaces, so we expect arguments value to be without spaces
148-
EXPECT_EQ(parsedOutput.toolCalls[0].arguments, "{\"arg1\":\"value1\",\"arg2\":42}");
149-
EXPECT_EQ(parsedOutput.toolCalls[0].id.empty(), false); // ID should be generated
141+
ASSERT_EQ(parsedOutput.toolCalls.size(), 0);
150142
}
151143
TEST_F(MistralOutputParserTest, ParseToolCallOutputWithMultipleToolCallsReturnsContentOnly) {
152144
std::string input = "[TOOL_CALLS][{\"name\": \"tool1\", \"arguments\": {\"a\": 1}}]</s> \n\nThis is some content\n\n[TOOL_CALLS][{\"name\": \"tool2\", \"arguments\": {\"b\": 2}}]</s>";

0 commit comments

Comments
 (0)