Skip to content

Commit 2fec53d

Browse files
authored
[TRTLLM-9637][feat] Support tool parser for Kimi K2 (NVIDIA#9830)
Signed-off-by: Junyi Xu <[email protected]>
1 parent 9df4dad commit 2fec53d

File tree

5 files changed

+374
-1
lines changed

5 files changed

+374
-1
lines changed

tensorrt_llm/serve/openai_server.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,10 @@ def __init__(self,
152152
else:
153153
self.use_harmony = (self.model_config.model_type == "gpt_oss")
154154

155+
self.tool_call_id_type = "random" # default tool call id type is random
156+
if self.model_config.model_type == "kimi_k2":
157+
self.tool_call_id_type = "kimi_k2"
158+
155159
# as disagg-worker
156160
self.disagg_cluster_storage = None
157161
self.disagg_cluster_worker = None
@@ -554,6 +558,7 @@ async def create_chat_response(
554558

555559
postproc_args.reasoning_parser = self.llm.args.reasoning_parser
556560
postproc_args.tool_parser = self.tool_parser
561+
postproc_args.tool_call_id_type = self.tool_call_id_type
557562
if conversation and conversation[-1].get(
558563
"content") and conversation[-1].get("role") == get_role():
559564
postproc_args.last_message_content = conversation[-1]["content"]

tensorrt_llm/serve/postprocess_handlers.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class ChatPostprocArgs(PostprocArgs):
5454
default_factory=dict)
5555
tool_parser_dict: dict[int, BaseToolParser] = field(default_factory=dict)
5656
has_tool_call: dict[int, bool] = field(default_factory=dict)
57+
tool_call_id_type: str = "random"
5758

5859
@classmethod
5960
def from_request(cls, request: ChatCompletionRequest):
@@ -223,7 +224,10 @@ def yield_first_chat(num_tokens: int,
223224
# Tool call ID should be generated only once per tool call
224225
if call_item.name:
225226
# First chunk: include ID and function name
226-
tool_call_id = make_tool_call_id()
227+
tool_call_id = make_tool_call_id(
228+
id_type=args.tool_call_id_type,
229+
func_name=call_item.name,
230+
idx=call_item.tool_index)
227231
function_name = call_item.name
228232
else:
229233
# Subsequent chunks: null ID and name for argument deltas
Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
# Adapted from https://github.com/sgl-project/sglang/blob/083629c23564e1a64deaa052f1df5c5d914358d8/python/sglang/srt/function_call/kimik2_detector.py
2+
import json
3+
import re
4+
from typing import List
5+
6+
from tensorrt_llm.logger import logger
7+
8+
from ..openai_protocol import ChatCompletionToolsParam as Tool
9+
from .base_tool_parser import BaseToolParser
10+
from .core_types import StreamingParseResult, StructureInfo, ToolCallItem, _GetInfoFunc
11+
12+
13+
class KimiK2ToolParser(BaseToolParser):
14+
"""Detector for Kimi K2 model function call format.
15+
16+
Format Structure:
17+
```
18+
<|tool_calls_section_begin|>
19+
<|tool_call_begin|>functions.{func_name}:{index}<|tool_call_argument_begin|>{json_args}<|tool_call_end|>
20+
<|tool_calls_section_end|>
21+
```
22+
23+
Reference: https://huggingface.co/moonshotai/Kimi-K2-Instruct/blob/main/docs/tool_call_guidance.md
24+
"""
25+
26+
def __init__(self):
27+
super().__init__()
28+
29+
self.bot_token: str = "<|tool_calls_section_begin|>"
30+
self.eot_token: str = "<|tool_calls_section_end|>"
31+
32+
self.tool_call_start_token: str = "<|tool_call_begin|>"
33+
self.tool_call_end_token: str = "<|tool_call_end|>"
34+
35+
self.tool_call_regex = re.compile(
36+
r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*?\})\s*<\|tool_call_end\|>"
37+
)
38+
39+
self.stream_tool_call_portion_regex = re.compile(
40+
r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*)"
41+
)
42+
43+
self._last_arguments = ""
44+
45+
# Robust parser for ids like "functions.search:0" or fallback "search:0"
46+
self.tool_call_id_regex = re.compile(r"^(?:functions\.)?(?P<name>[\w\.]+):(?P<index>\d+)$")
47+
48+
def has_tool_call(self, text: str) -> bool:
49+
"""Check if the text contains a KimiK2 format tool call."""
50+
return self.bot_token in text
51+
52+
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
53+
"""One-time parsing: Detects and parses tool calls in the provided text.
54+
55+
:param text: The complete text to parse.
56+
:param tools: List of available tools.
57+
:return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
58+
"""
59+
if self.bot_token not in text:
60+
return StreamingParseResult(normal_text=text, calls=[])
61+
try:
62+
# there are two possible captures - between tags, or between a
63+
# tag and end-of-string so the result of
64+
# findall is an array of tuples where one is a function call and
65+
# the other is None
66+
function_call_tuples = self.tool_call_regex.findall(text)
67+
tool_indices = self._get_tool_indices(tools)
68+
69+
logger.debug("function_call_tuples: %s", function_call_tuples)
70+
71+
tool_calls = []
72+
for match in function_call_tuples:
73+
function_id, function_args = match
74+
m = self.tool_call_id_regex.match(function_id)
75+
if not m:
76+
logger.warning("Unexpected tool_call_id format: %s", function_id)
77+
continue
78+
function_name = m.group("name")
79+
function_idx = int(m.group("index"))
80+
81+
if function_name not in tool_indices:
82+
logger.warning(f"Model attempted to call undefined function: {function_name}")
83+
continue
84+
85+
logger.debug(f"function_name {function_name}")
86+
87+
tool_calls.append(
88+
ToolCallItem(
89+
tool_index=function_idx,
90+
name=function_name,
91+
parameters=function_args,
92+
)
93+
)
94+
95+
content = text[: text.find(self.bot_token)]
96+
return StreamingParseResult(normal_text=content, calls=tool_calls)
97+
98+
except Exception as e:
99+
logger.error(f"Error in detect_and_parse: {e}")
100+
# return the normal text if parsing fails
101+
return StreamingParseResult(normal_text=text)
102+
103+
def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> StreamingParseResult:
104+
"""Streaming incremental parsing tool calls for KimiK2 format."""
105+
self._buffer += new_text
106+
current_text = self._buffer
107+
108+
# Check if we have a tool call (either the start token or individual tool call)
109+
has_tool_call = self.bot_token in current_text or self.tool_call_start_token in current_text
110+
111+
if not has_tool_call:
112+
self._buffer = ""
113+
for e_token in [self.eot_token, self.tool_call_end_token]:
114+
if e_token in new_text:
115+
new_text = new_text.replace(e_token, "")
116+
return StreamingParseResult(normal_text=new_text)
117+
118+
if not hasattr(self, "_tool_indices"):
119+
self._tool_indices = self._get_tool_indices(tools)
120+
121+
calls: list[ToolCallItem] = []
122+
try:
123+
match = self.stream_tool_call_portion_regex.search(current_text)
124+
if match:
125+
function_id = match.group("tool_call_id")
126+
function_args = match.group("function_arguments")
127+
128+
m = self.tool_call_id_regex.match(function_id)
129+
if not m:
130+
logger.warning("Unexpected tool_call_id format: %s", function_id)
131+
return StreamingParseResult(normal_text="", calls=calls)
132+
function_name = m.group("name")
133+
134+
# Initialize state if this is the first tool call
135+
if self.current_tool_id == -1:
136+
self.current_tool_id = 0
137+
self.prev_tool_call_arr = []
138+
self.streamed_args_for_tool = [""]
139+
140+
# Ensure we have enough entries in our tracking arrays
141+
while len(self.prev_tool_call_arr) <= self.current_tool_id:
142+
self.prev_tool_call_arr.append({})
143+
while len(self.streamed_args_for_tool) <= self.current_tool_id:
144+
self.streamed_args_for_tool.append("")
145+
146+
if not self.current_tool_name_sent:
147+
calls.append(
148+
ToolCallItem(
149+
tool_index=self.current_tool_id,
150+
name=function_name,
151+
parameters="",
152+
)
153+
)
154+
self.current_tool_name_sent = True
155+
# Store the tool call info for serving layer completions endpoint
156+
self.prev_tool_call_arr[self.current_tool_id] = {
157+
"name": function_name,
158+
"arguments": {},
159+
}
160+
else:
161+
argument_diff = (
162+
function_args[len(self._last_arguments) :]
163+
if function_args.startswith(self._last_arguments)
164+
else function_args
165+
)
166+
167+
parsed_args_diff = argument_diff.split("<|tool_call_end|>", 1)[0]
168+
169+
if parsed_args_diff:
170+
calls.append(
171+
ToolCallItem(
172+
tool_index=self.current_tool_id,
173+
name=None,
174+
parameters=parsed_args_diff,
175+
)
176+
)
177+
self._last_arguments += argument_diff
178+
self.streamed_args_for_tool[self.current_tool_id] += parsed_args_diff
179+
180+
parsed_args = function_args.split("<|tool_call_end|>", 1)[0]
181+
try:
182+
parsed_args = json.loads(parsed_args)
183+
self.prev_tool_call_arr[self.current_tool_id]["arguments"] = parsed_args
184+
185+
# Find the end of the current tool call and remove only that part from buffer
186+
tool_call_end_pattern = r"<\|tool_call_begin\|>.*?<\|tool_call_end\|>"
187+
match = re.search(tool_call_end_pattern, current_text, re.DOTALL)
188+
if match:
189+
# Remove the completed tool call from buffer, keep any remaining content
190+
self._buffer = current_text[match.end() :]
191+
else:
192+
self._buffer = ""
193+
194+
result = StreamingParseResult(normal_text="", calls=calls)
195+
self.current_tool_id += 1
196+
self._last_arguments = ""
197+
self.current_tool_name_sent = False
198+
return result
199+
except json.JSONDecodeError:
200+
pass
201+
202+
return StreamingParseResult(normal_text="", calls=calls)
203+
204+
except Exception as e:
205+
logger.error(f"Error in parse_streaming_increment: {e}")
206+
return StreamingParseResult(normal_text=current_text)
207+
208+
def structure_info(self) -> _GetInfoFunc:
209+
"""Return function that creates StructureInfo for guided generation."""
210+
211+
def get_info(name: str) -> StructureInfo:
212+
return StructureInfo(
213+
begin=f"<|tool_calls_section_begin|><|tool_call_begin|>functions.{name}:0<|tool_call_argument_begin|>",
214+
end="<|tool_call_end|><|tool_calls_section_end|>",
215+
trigger="<|tool_calls_section_begin|>",
216+
)
217+
218+
return get_info

tensorrt_llm/serve/tool_parser/tool_parser_factory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Type
22

33
from .base_tool_parser import BaseToolParser
4+
from .kimi_k2_tool_parser import KimiK2ToolParser
45
from .qwen3_coder_parser import Qwen3CoderToolParser
56
from .qwen3_tool_parser import Qwen3ToolParser
67

@@ -9,6 +10,7 @@ class ToolParserFactory:
910
parsers: dict[str, Type[BaseToolParser]] = {
1011
"qwen3": Qwen3ToolParser,
1112
"qwen3_coder": Qwen3CoderToolParser,
13+
"kimi_k2": KimiK2ToolParser,
1214
}
1315

1416
@staticmethod

0 commit comments

Comments
 (0)