Skip to content

Commit 9251ed5

Browse files
koush, mondaylord, chaunceyjiang, and ywang96
authored
[Bugfix] Handle case when kimi ends reasoning with a tool call (vllm-project#33646)
Signed-off-by: Koushik Dutta <koushd@gmail.com> Co-authored-by: mondaylord <20212010046@fudan.edu.cn> Co-authored-by: Chauncey <chaunceyjiang@gmail.com> Co-authored-by: Roger Wang <hey@rogerw.io>
1 parent e824937 commit 9251ed5

File tree

2 files changed

+230
-2
lines changed

2 files changed

+230
-2
lines changed

vllm/reasoning/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@
5353
"HunyuanA13BReasoningParser",
5454
),
5555
"kimi_k2": (
56-
"deepseek_v3_reasoning_parser",
57-
"DeepSeekV3ReasoningWithThinkingParser",
56+
"kimi_k2_reasoning_parser",
57+
"KimiK2ReasoningParser",
5858
),
5959
"minimax_m2": (
6060
"minimax_m2_reasoning_parser",
Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
from collections.abc import Sequence
5+
6+
from transformers import PreTrainedTokenizerBase
7+
8+
from vllm.entrypoints.openai.chat_completion.protocol import (
9+
ChatCompletionRequest,
10+
)
11+
from vllm.entrypoints.openai.engine.protocol import DeltaMessage
12+
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
13+
from vllm.reasoning.identity_reasoning_parser import IdentityReasoningParser
14+
15+
16+
class KimiK2ReasoningParser(ReasoningParser):
    """
    Reasoning parser for the Kimi K2 model.

    The Kimi K2 model uses <think>...</think> tokens to denote reasoning
    text, and may implicitly end reasoning by starting a tool call section
    with <|tool_calls_section_begin|>. Thinking may also begin without a
    <think> token.

    Kimi's thinking mode can be disabled via ``chat_template_kwargs``; in
    that case this parser delegates every call to IdentityReasoningParser,
    so all model output is treated as content.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ReasoningParser "
                "constructor during construction."
            )

        # Check if thinking is disabled via chat_template_kwargs.
        chat_kwargs = kwargs.get("chat_template_kwargs", {}) or {}
        thinking = bool(chat_kwargs.get("thinking", True))

        # If thinking is disabled, delegate to an identity parser so the
        # model output falls through unchanged.
        self._identity_parser = (
            None
            if thinking
            else IdentityReasoningParser(tokenizer, *args, **kwargs)
        )

        # Marker strings used by the Kimi K2 chat format.
        self._start_token = "<think>"
        self._end_token = "</think>"
        self._tool_section_start_token = "<|tool_calls_section_begin|>"

        # Token ids; the tool-section token may legitimately be absent from
        # the vocabulary, so it is guarded with `is not None` at use sites.
        self._start_token_id = self.vocab.get(self._start_token)
        self._end_token_id = self.vocab.get(self._end_token)
        self._tool_section_start_token_id = self.vocab.get(
            self._tool_section_start_token
        )

        if self._start_token_id is None or self._end_token_id is None:
            raise RuntimeError(
                "KimiK2ReasoningParser could not locate think start/end "
                "tokens in the tokenizer!"
            )

    def _is_identity_mode(self) -> bool:
        """Check if parser is in identity mode (no reasoning extraction)."""
        return self._identity_parser is not None

    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        """
        Check if the reasoning content ends in ``input_ids``.

        Scanning backwards, reasoning has ended when either:
        1. the explicit end token (</think>), or
        2. the tool section start token (<|tool_calls_section_begin|>)
        is seen before any start token (<think>).
        """
        if self._is_identity_mode():
            return self._identity_parser.is_reasoning_end(input_ids)

        start_token_id = self._start_token_id
        end_token_id = self._end_token_id
        tool_section_start_token_id = self._tool_section_start_token_id

        # Walk backwards so the most recent marker wins.
        for token_id in reversed(input_ids):
            if token_id == start_token_id:
                return False
            if token_id == end_token_id:
                return True
            # Implicit reasoning end via tool call section.
            if (
                tool_section_start_token_id is not None
                and token_id == tool_section_start_token_id
            ):
                return True
        return False

    def is_reasoning_end_streaming(
        self, input_ids: Sequence[int], delta_ids: Sequence[int]
    ) -> bool:
        """
        Check if the reasoning content ends in the input_ids on a decode
        step, i.e. whether ``delta_ids`` contains an explicit </think> or
        an implicit <|tool_calls_section_begin|> terminator.
        """
        if self._is_identity_mode():
            return self._identity_parser.is_reasoning_end_streaming(
                input_ids, delta_ids
            )

        if self._end_token_id in delta_ids:
            return True
        return (
            self._tool_section_start_token_id is not None
            and self._tool_section_start_token_id in delta_ids
        )

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """
        Extract the content (non-reasoning) token ids from ``input_ids``.

        Content starts after the last </think> token, or at the last
        <|tool_calls_section_begin|> token (the tool section itself is
        content). Returns ``[]`` while still reasoning.
        """
        if self._is_identity_mode():
            return self._identity_parser.extract_content_ids(input_ids)

        if self._end_token_id in input_ids:
            # Index of the last occurrence of the end token. The `in`
            # check above guarantees this index is valid, so no extra
            # -1 guard is needed.
            end_token_index = (
                len(input_ids) - 1 - input_ids[::-1].index(self._end_token_id)
            )
            return input_ids[end_token_index + 1 :]

        if (
            self._tool_section_start_token_id is not None
            and self._tool_section_start_token_id in input_ids
        ):
            tool_section_index = (
                len(input_ids)
                - 1
                - input_ids[::-1].index(self._tool_section_start_token_id)
            )
            # Keep the tool-section token itself: it belongs to the content.
            return input_ids[tool_section_index:]

        # Still reasoning (no content).
        return []

    def extract_reasoning(
        self, model_output: str, request: ChatCompletionRequest
    ) -> tuple[str | None, str | None]:
        """
        Extract ``(reasoning, content)`` from a complete model output.

        Thinking does not require a <think> start token, but a leading one
        is consumed if present. Reasoning ends at </think> or, implicitly,
        at <|tool_calls_section_begin|> (which is kept in the content).
        """
        if self._is_identity_mode():
            return self._identity_parser.extract_reasoning(model_output, request)

        # Consume a leading <think> token if present; otherwise reasoning
        # starts at the beginning of the output.
        reasoning_start = (
            len(self._start_token)
            if model_output.startswith(self._start_token)
            else 0
        )
        end_token_index = model_output.find(self._end_token)

        if end_token_index != -1:
            return (
                model_output[reasoning_start:end_token_index],
                model_output[end_token_index + len(self._end_token) :] or None,
            )

        tool_section_index = model_output.find(self._tool_section_start_token)
        if tool_section_index != -1:
            return (
                model_output[reasoning_start:tool_section_index],
                model_output[tool_section_index:] or None,
            )

        # Still reasoning (no content).
        return (
            model_output[reasoning_start:],
            None,
        )

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """
        Extract reasoning content from a delta message during streaming.

        Returns a DeltaMessage whose ``reasoning``/``content`` fields carry
        the split of ``delta_text``, or None when the delta is a lone
        <think>/</think> special token.
        """
        if self._is_identity_mode():
            return self._identity_parser.extract_reasoning_streaming(
                previous_text,
                current_text,
                delta_text,
                previous_token_ids,
                current_token_ids,
                delta_token_ids,
            )

        # If reasoning already ended in previous tokens, this is content.
        if self.is_reasoning_end(previous_token_ids):
            return DeltaMessage(content=delta_text)

        # Skip single special tokens.
        if len(delta_token_ids) == 1 and delta_token_ids[0] in (
            self._start_token_id,
            self._end_token_id,
        ):
            return None

        if self._end_token_id in delta_token_ids:
            end_index = delta_text.find(self._end_token)
            reasoning = delta_text[:end_index]
            content = delta_text[end_index + len(self._end_token) :]
            return DeltaMessage(
                reasoning=reasoning, content=content if content else None
            )

        # Guard against a missing tool-section token id, consistent with
        # the other methods of this class.
        if (
            self._tool_section_start_token_id is not None
            and self._tool_section_start_token_id in delta_token_ids
        ):
            tool_index = delta_text.find(self._tool_section_start_token)
            reasoning = delta_text[:tool_index]
            # The tool-section marker itself belongs to the content.
            content = delta_text[tool_index:]
            return DeltaMessage(reasoning=reasoning, content=content)

        # Still reasoning (no end token).
        return DeltaMessage(reasoning=delta_text)

0 commit comments

Comments
 (0)