Skip to content

Commit b3a0b61

Browse files
authored
feat: add streaming and non-streaming function call support for GLM-4.5. (#303)
1 parent b05f5ef commit b3a0b61

File tree

8 files changed

+823
-13
lines changed

8 files changed

+823
-13
lines changed

xllm/api_service/chat_service_impl.cpp

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -125,16 +125,40 @@ void set_logprobs(proto::ChatChoice* choice,
125125
}
126126

127127
struct StreamingState {
128-
std::unique_ptr<function_call::FunctionCallParser> parser;
129-
std::unordered_map<size_t, bool> has_tool_calls;
128+
std::vector<function_call::JsonTool> tools;
129+
std::string parser_format;
130+
131+
std::vector<std::unique_ptr<function_call::FunctionCallParser>> parsers;
132+
std::vector<bool> has_tool_calls;
130133

131134
StreamingState(const std::vector<function_call::JsonTool>& tools,
132-
const std::string& parser_format) {
135+
const std::string& parser_format)
136+
: tools(tools), parser_format(parser_format) {
133137
if (!tools.empty() && !parser_format.empty()) {
134-
parser = std::make_unique<function_call::FunctionCallParser>(
138+
parsers.resize(1);
139+
has_tool_calls.resize(1, false);
140+
parsers[0] = std::make_unique<function_call::FunctionCallParser>(
135141
tools, parser_format);
136142
}
137143
}
144+
145+
function_call::FunctionCallParser* get_parser_for_sequence(size_t index) {
146+
if (tools.empty() || parser_format.empty()) {
147+
return nullptr;
148+
}
149+
150+
if (index >= parsers.size()) {
151+
parsers.resize(index + 1);
152+
has_tool_calls.resize(index + 1, false);
153+
}
154+
155+
if (!parsers[index]) {
156+
parsers[index] = std::make_unique<function_call::FunctionCallParser>(
157+
tools, parser_format);
158+
}
159+
160+
return parsers[index].get();
161+
}
138162
};
139163

140164
template <typename ChatCall>
@@ -206,11 +230,12 @@ bool process_tool_call_stream(std::shared_ptr<ChatCall> call,
206230
const std::string& request_id,
207231
int64_t created_time,
208232
const std::string& model) {
209-
if (!streaming_state->parser) {
233+
auto* parser = streaming_state->get_parser_for_sequence(index);
234+
if (!parser) {
210235
return true;
211236
}
212237

213-
auto parse_result = streaming_state->parser->parse_streaming_increment(delta);
238+
auto parse_result = parser->parse_streaming_increment(delta);
214239

215240
if (!parse_result.normal_text.empty()) {
216241
if (!send_normal_text_chunk(call,
@@ -224,6 +249,9 @@ bool process_tool_call_stream(std::shared_ptr<ChatCall> call,
224249
}
225250

226251
for (const auto& call_item : parse_result.calls) {
252+
if (index >= streaming_state->has_tool_calls.size()) {
253+
streaming_state->has_tool_calls.resize(index + 1, false);
254+
}
227255
streaming_state->has_tool_calls[index] = true;
228256

229257
std::string tool_call_id;
@@ -258,11 +286,12 @@ bool check_for_unstreamed_tool_args(
258286
const std::string& request_id,
259287
int64_t created_time,
260288
const std::string& model) {
261-
if (!streaming_state->parser) {
289+
auto* parser = streaming_state->get_parser_for_sequence(index);
290+
if (!parser) {
262291
return true;
263292
}
264293

265-
auto* detector = streaming_state->parser->get_detector();
294+
auto* detector = parser->get_detector();
266295
if (!detector) {
267296
return true;
268297
}
@@ -335,7 +364,7 @@ bool send_delta_to_client_brpc(
335364
}
336365

337366
if (!seq_output.text.empty()) {
338-
if (streaming_state && streaming_state->parser) {
367+
if (streaming_state && streaming_state->get_parser_for_sequence(index)) {
339368
if (!process_tool_call_stream(call,
340369
streaming_state,
341370
index,
@@ -365,7 +394,8 @@ bool send_delta_to_client_brpc(
365394
// Handle finish reason
366395
if (seq_output.finish_reason.has_value()) {
367396
// Check for unstreamed tool args before sending finish reason
368-
if (streaming_state && streaming_state->has_tool_calls[index]) {
397+
if (streaming_state && index < streaming_state->has_tool_calls.size() &&
398+
streaming_state->has_tool_calls[index]) {
369399
if (!check_for_unstreamed_tool_args(call,
370400
streaming_state,
371401
index,
@@ -385,7 +415,8 @@ bool send_delta_to_client_brpc(
385415
choice->set_index(index);
386416
choice->mutable_delta();
387417

388-
if (streaming_state && streaming_state->has_tool_calls[index] &&
418+
if (streaming_state && index < streaming_state->has_tool_calls.size() &&
419+
streaming_state->has_tool_calls[index] &&
389420
seq_output.finish_reason.value() == "stop") {
390421
choice->set_finish_reason("tool_calls");
391422
} else {

xllm/core/common/global_flags.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ DEFINE_bool(enable_shm,
283283
DEFINE_string(tool_call_parser,
284284
"",
285285
"Specify the parser for handling tool-call interactions(e.g. "
286-
"qwen25, qwen3, kimi_k2, deepseekv3).");
286+
"qwen25, qwen3, kimi_k2, deepseekv3, glm45).");
287287

288288
// --- speculative config ---
289289

xllm/function_call/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ cc_library (
1212
qwen25_detector.h
1313
kimik2_detector.h
1414
deepseekv3_detector.h
15+
glm45_detector.h
1516
function_call_parser.h
1617
function_call.h
1718
utils.h
@@ -20,6 +21,7 @@ cc_library (
2021
qwen25_detector.cpp
2122
kimik2_detector.cpp
2223
deepseekv3_detector.cpp
24+
glm45_detector.cpp
2325
function_call_parser.cpp
2426
utils.cpp
2527
DEPS
@@ -47,4 +49,5 @@ endfunction()
4749
add_detector_test(qwen25_detector_test)
4850
add_detector_test(kimik2_detector_test)
4951
add_detector_test(deepseekv3_detector_test)
52+
add_detector_test(glm45_detector_test)
5053

xllm/function_call/function_call.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ limitations under the License.
1919
#include "core_types.h"
2020
#include "deepseekv3_detector.h"
2121
#include "function_call_parser.h"
22+
#include "glm45_detector.h"
2223
#include "kimik2_detector.h"
2324
#include "qwen25_detector.h"
2425

xllm/function_call/function_call_parser.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ limitations under the License.
2020

2121
#include "core/util/uuid.h"
2222
#include "deepseekv3_detector.h"
23+
#include "glm45_detector.h"
2324
#include "kimik2_detector.h"
2425
#include "qwen25_detector.h"
2526
namespace xllm {
@@ -31,12 +32,12 @@ const std::unordered_map<std::string, std::string>
3132
{"qwen3", "qwen25"},
3233
{"kimi_k2", "kimi_k2"},
3334
{"deepseekv3", "deepseekv3"},
35+
{"glm45", "glm45"},
3436
// TODO
3537
// {"llama3", "llama3"},
3638
// {"mistral", "mistral"},
3739
// {"pythonic", "pythonic"},
3840
// {"qwen3_coder", "qwen3_coder"},
39-
// {"glm45", "glm45"},
4041
// {"step3", "step3"},
4142
};
4243

@@ -96,6 +97,10 @@ std::unique_ptr<BaseFormatDetector> FunctionCallParser::create_detector(
9697
return std::make_unique<DeepSeekV3Detector>();
9798
}
9899

100+
if (it->second == "glm45") {
101+
return std::make_unique<Glm45Detector>();
102+
}
103+
99104
// if (tool_call_parser == "llama3") {
100105
// return std::make_unique<Llama32Detector>();
101106
// }
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
https://github.com/jd-opensource/xllm/blob/main/LICENSE
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "glm45_detector.h"
17+
18+
#include <algorithm>
19+
#include <iostream>
20+
#include <sstream>
21+
22+
namespace xllm {
23+
namespace function_call {
24+
25+
Glm45Detector::Glm45Detector() : BaseFormatDetector() {
26+
bot_token_ = "<tool_call>";
27+
eot_token_ = "</tool_call>";
28+
29+
// Regex patterns for GLM-4.5 format
30+
func_call_regex_ = std::regex("<tool_call>[\\s\\S]*?</tool_call>",
31+
std::regex_constants::ECMAScript);
32+
func_detail_regex_ =
33+
std::regex("<tool_call>([^\\n]*)\\n([\\s\\S]*?)</tool_call>",
34+
std::regex_constants::ECMAScript);
35+
func_arg_regex_ = std::regex(
36+
"<arg_key>([\\s\\S]*?)</arg_key>\\s*<arg_value>([\\s\\S]*?)</arg_value>",
37+
std::regex_constants::ECMAScript);
38+
}
39+
40+
std::string Glm45Detector::trim_whitespace(std::string_view str) const {
41+
const char* whitespace = " \t\n\r";
42+
43+
size_t start = str.find_first_not_of(whitespace);
44+
if (start == std::string_view::npos) {
45+
return std::string{};
46+
}
47+
48+
size_t end = str.find_last_not_of(whitespace);
49+
50+
return std::string(str.substr(start, end - start + 1));
51+
}
52+
53+
bool Glm45Detector::has_tool_call(const std::string& text) {
54+
return text.find(bot_token_) != std::string::npos;
55+
}
56+
57+
StreamingParseResult Glm45Detector::detect_and_parse(
58+
const std::string& text,
59+
const std::vector<JsonTool>& tools) {
60+
size_t idx = text.find(bot_token_);
61+
std::string normal_text =
62+
(idx != std::string::npos) ? text.substr(0, idx) : text;
63+
64+
// Trim normal text
65+
if (!normal_text.empty()) {
66+
normal_text = trim_whitespace(normal_text);
67+
}
68+
69+
if (idx == std::string::npos) {
70+
return StreamingParseResult(normal_text, {});
71+
}
72+
73+
std::vector<ToolCallItem> calls;
74+
75+
try {
76+
std::sregex_iterator iter(text.begin(), text.end(), func_call_regex_);
77+
std::sregex_iterator end;
78+
79+
for (; iter != end; ++iter) {
80+
std::smatch match = *iter;
81+
std::string match_result = match.str();
82+
83+
// Parse function name and arguments
84+
std::smatch func_detail;
85+
if (std::regex_search(match_result, func_detail, func_detail_regex_)) {
86+
std::string func_name = func_detail[1].str();
87+
std::string func_args = func_detail[2].str();
88+
89+
// Parse arguments using regex
90+
std::unordered_map<std::string, nlohmann::json> arguments;
91+
std::sregex_iterator arg_iter(
92+
func_args.begin(), func_args.end(), func_arg_regex_);
93+
std::sregex_iterator arg_end;
94+
95+
for (; arg_iter != arg_end; ++arg_iter) {
96+
std::smatch arg_match = *arg_iter;
97+
if (arg_match.size() >= 3) {
98+
std::string arg_key = arg_match[1].str();
99+
std::string arg_value = arg_match[2].str();
100+
101+
arg_key = trim_whitespace(arg_key);
102+
103+
arg_value = trim_whitespace(arg_value);
104+
105+
try {
106+
nlohmann::json parsed_value = nlohmann::json::parse(arg_value);
107+
arguments[arg_key] = parsed_value;
108+
} catch (const nlohmann::json::parse_error&) {
109+
arguments[arg_key] = nlohmann::json(arg_value);
110+
}
111+
}
112+
}
113+
114+
// Create JSON object for parse_base_json
115+
nlohmann::json match_json;
116+
match_json["name"] = func_name;
117+
match_json["parameters"] = arguments;
118+
119+
auto parsed_calls = parse_base_json(match_json, tools);
120+
calls.insert(calls.end(), parsed_calls.begin(), parsed_calls.end());
121+
}
122+
}
123+
124+
return StreamingParseResult(normal_text, calls);
125+
126+
} catch (const std::exception& e) {
127+
LOG(ERROR) << "Error in GLM-4.5 detect_and_parse: " << e.what();
128+
return StreamingParseResult(text, {});
129+
}
130+
}
131+
132+
StreamingParseResult Glm45Detector::parse_streaming_increment(
133+
const std::string& new_text,
134+
const std::vector<JsonTool>& tools) {
135+
buffer_ += new_text;
136+
std::string current_text = buffer_;
137+
138+
size_t start = current_text.find(bot_token_);
139+
if (start == std::string::npos) {
140+
buffer_.clear();
141+
if (current_tool_id_ > 0) {
142+
current_text = "";
143+
}
144+
return StreamingParseResult(current_text, {});
145+
}
146+
147+
// Look for complete tool call
148+
size_t end = current_text.find(eot_token_);
149+
if (end != std::string::npos) {
150+
// Initialize state if this is the first tool call
151+
if (current_tool_id_ == -1) {
152+
current_tool_id_ = 0;
153+
prev_tool_call_arr_.clear();
154+
streamed_args_for_tool_.clear();
155+
streamed_args_for_tool_.push_back("");
156+
}
157+
158+
// Ensure we have enough entries in tracking arrays
159+
while (prev_tool_call_arr_.size() <= current_tool_id_) {
160+
prev_tool_call_arr_.push_back({});
161+
}
162+
while (streamed_args_for_tool_.size() <= current_tool_id_) {
163+
streamed_args_for_tool_.push_back("");
164+
}
165+
166+
// Parse the complete tool call
167+
std::string complete_call =
168+
current_text.substr(0, end + eot_token_.length());
169+
StreamingParseResult result = detect_and_parse(complete_call, tools);
170+
171+
if (!result.calls.empty()) {
172+
// Store tool call info for serving layer
173+
prev_tool_call_arr_[current_tool_id_]["name"] =
174+
result.calls[0].name.value_or("");
175+
prev_tool_call_arr_[current_tool_id_]["arguments"] =
176+
result.calls[0].parameters;
177+
streamed_args_for_tool_[current_tool_id_] = result.calls[0].parameters;
178+
179+
// Update tool index
180+
result.calls[0].tool_index = current_tool_id_;
181+
current_tool_id_++;
182+
}
183+
184+
// Update buffer with remaining text
185+
buffer_ = current_text.substr(end + eot_token_.length());
186+
return result;
187+
}
188+
189+
// Return normal text before tool call start
190+
std::string normal_text = current_text.substr(0, start);
191+
buffer_ = current_text.substr(start);
192+
return StreamingParseResult(normal_text, {});
193+
}
194+
195+
} // namespace function_call
196+
} // namespace xllm

0 commit comments

Comments
 (0)