Skip to content

Commit a4227cf

Browse files
authored
[None][feat] Support Qwen3 reasoning parser (#8000)
Signed-off-by: Pengyun Lin <[email protected]>
1 parent 0acd10e commit a4227cf

File tree

4 files changed

+164
-110
lines changed

4 files changed

+164
-110
lines changed
Lines changed: 73 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,12 @@
11
from abc import ABC, abstractmethod
22
from dataclasses import dataclass
3-
from typing import Dict, Optional
3+
from typing import Type
44

55

66
@dataclass
77
class ReasoningParserResult:
8-
9-
def __init__(self,
10-
in_reasoning: bool,
11-
content: Optional[str] = None,
12-
reasoning_content: Optional[str] = None):
13-
self.in_reasoning = in_reasoning
14-
self.content = content
15-
self.reasoning_content = reasoning_content
8+
content: str = ""
9+
reasoning_content: str = ""
1610

1711

1812
class BaseReasoningParser(ABC):
@@ -34,62 +28,99 @@ class DeepSeekR1Parser(BaseReasoningParser):
3428
treat all the text before the </think> tag as `reasoning_content` and the text after as `content`.
3529
"""
3630

37-
def __init__(self):
31+
def __init__(self, reasoning_at_start: bool = False) -> None:
32+
self.reasoning_start = "<think>"
3833
self.reasoning_end = "</think>"
39-
self.in_reasoning = True
34+
self.reasoning_at_start = reasoning_at_start
35+
self.in_reasoning = self.reasoning_at_start
36+
self._buffer = ""
4037

4138
def _create_reasoning_end_result(self, content: str,
4239
reasoning_content: str):
4340
if len(content) == 0:
4441
reasoning_parser_result = ReasoningParserResult(
45-
True, reasoning_content=reasoning_content)
42+
reasoning_content=reasoning_content)
4643
elif len(reasoning_content) == 0:
47-
reasoning_parser_result = ReasoningParserResult(False,
48-
content=content)
44+
reasoning_parser_result = ReasoningParserResult(content=content)
4945
else:
5046
reasoning_parser_result = ReasoningParserResult(
51-
False, content=content, reasoning_content=reasoning_content)
47+
content=content, reasoning_content=reasoning_content)
5248
return reasoning_parser_result
5349

5450
def parse(self, text: str) -> ReasoningParserResult:
55-
if self.reasoning_end not in text:
56-
return ReasoningParserResult(True, reasoning_content=text)
57-
58-
splits = text.split(self.reasoning_end, maxsplit=1)
59-
reasoning_content = splits[0]
60-
content = splits[1]
61-
62-
reasoning_parser_result = self._create_reasoning_end_result(
63-
content, reasoning_content)
64-
return reasoning_parser_result
51+
if not self.reasoning_at_start:
52+
splits = text.partition(self.reasoning_start)
53+
if splits[1] == "":
54+
# no reasoning start tag found
55+
return ReasoningParserResult(content=text)
56+
# reasoning start tag found
57+
# text before reasoning start tag is dropped
58+
text = splits[2]
59+
splits = text.partition(self.reasoning_end)
60+
reasoning_content, content = splits[0], splits[2]
61+
return ReasoningParserResult(content=content,
62+
reasoning_content=reasoning_content)
6563

6664
def parse_delta(self, delta_text: str) -> ReasoningParserResult:
67-
if self.in_reasoning and self.reasoning_end in delta_text:
65+
self._buffer += delta_text
66+
delta_text = self._buffer
67+
reasoning_content = None
68+
content = None
69+
if (self.reasoning_start.startswith(delta_text)
70+
or self.reasoning_end.startswith(delta_text)):
71+
# waiting for more text to determine if it's a reasoning start or end tag
72+
return ReasoningParserResult()
73+
74+
if not self.in_reasoning:
75+
begin_idx = delta_text.find(self.reasoning_start)
76+
if begin_idx == -1:
77+
self._buffer = ""
78+
return ReasoningParserResult(content=delta_text)
79+
self.in_reasoning = True
80+
# set reasoning_content, will be processed by the next block
81+
reasoning_content = delta_text[begin_idx +
82+
len(self.reasoning_start):]
83+
84+
if self.in_reasoning:
85+
delta_text = reasoning_content if reasoning_content is not None else delta_text
6886
end_idx = delta_text.find(self.reasoning_end)
87+
if end_idx == -1:
88+
last_idx = delta_text.rfind(self.reasoning_end[0])
89+
if last_idx != -1 and self.reasoning_end.startswith(
90+
delta_text[last_idx:]):
91+
self._buffer = delta_text[last_idx:]
92+
reasoning_content = delta_text[:last_idx]
93+
else:
94+
self._buffer = ""
95+
reasoning_content = delta_text
96+
return ReasoningParserResult(
97+
reasoning_content=reasoning_content)
6998
reasoning_content = delta_text[:end_idx]
7099
content = delta_text[end_idx + len(self.reasoning_end):]
71-
reasoning_parser_result = self._create_reasoning_end_result(
72-
content, reasoning_content)
73100
self.in_reasoning = False
74-
return reasoning_parser_result
75-
76-
if self.in_reasoning:
77-
return ReasoningParserResult(self.in_reasoning,
78-
reasoning_content=delta_text)
79-
80-
# not self.in_reasoning:
81-
return ReasoningParserResult(self.in_reasoning, content=delta_text)
101+
self._buffer = ""
102+
return ReasoningParserResult(content=content,
103+
reasoning_content=reasoning_content)
104+
raise RuntimeError(
105+
"Unreachable code reached in `DeepSeekR1Parser.parse_delta`")
82106

83107

84108
class ReasoningParserFactory:
85-
parsers: Dict[str, BaseReasoningParser] = {
109+
parsers: dict[str, Type[BaseReasoningParser]] = {
86110
"deepseek-r1": DeepSeekR1Parser,
111+
"qwen3": DeepSeekR1Parser,
87112
}
88113

89114
@staticmethod
90115
def create_reasoning_parser(reasoning_parser: str) -> BaseReasoningParser:
91-
if reasoning_parser not in ReasoningParserFactory.parsers:
92-
raise ValueError(f"Invalid reasoning_parser: {reasoning_parser}")
93-
reasoning_parser_class = ReasoningParserFactory.parsers.get(
94-
reasoning_parser.lower())
95-
return reasoning_parser_class()
116+
try:
117+
reasoning_parser_class = ReasoningParserFactory.parsers[
118+
reasoning_parser.lower()]
119+
if reasoning_parser == "deepseek-r1":
120+
return reasoning_parser_class(reasoning_at_start=True)
121+
return reasoning_parser_class()
122+
except KeyError as e:
123+
raise ValueError(
124+
f"Invalid reasoning parser: {reasoning_parser}\n"
125+
f"Supported parsers: {list(ReasoningParserFactory.parsers.keys())}"
126+
) from e

tensorrt_llm/serve/postprocess_handlers.py

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def create_logprobs(token_ids: List[int], tokenizer: TransformersTokenizer,
9595

9696

9797
def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str,
98-
streaming: bool) -> Tuple[bool, str, str]:
98+
streaming: bool) -> Tuple[str, str]:
9999
reasoning_parser = None
100100
if args.reasoning_parser is not None:
101101
if output_index not in args.reasoning_parser_dict:
@@ -104,17 +104,16 @@ def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str,
104104
args.reasoning_parser)
105105
reasoning_parser = args.reasoning_parser_dict[output_index]
106106

107-
in_reasoning = False
108107
if reasoning_parser is not None:
109108
if not streaming:
110109
result = reasoning_parser.parse(text)
111110
else:
112111
result = reasoning_parser.parse_delta(text)
113-
in_reasoning, content, reasoning_content = result.in_reasoning, result.content, result.reasoning_content
112+
content, reasoning_content = result.content, result.reasoning_content
114113
else:
115-
in_reasoning, content, reasoning_content = False, text, None
114+
content, reasoning_content = text, ""
116115

117-
return in_reasoning, content, reasoning_content
116+
return content, reasoning_content
118117

119118

120119
@nvtx_range_debug("chat_stream_post_processor")
@@ -123,8 +122,8 @@ def chat_stream_post_processor(rsp: GenerationResultBase,
123122

124123
def yield_first_chat(num_tokens: int,
125124
idx: int,
126-
role: str = None,
127-
content: str = None):
125+
role: str | None = None,
126+
content: str | None = None):
128127
choice_data = ChatCompletionResponseStreamChoice(index=idx,
129128
delta=DeltaMessage(
130129
role=role,
@@ -171,7 +170,7 @@ def yield_first_chat(num_tokens: int,
171170

172171
delta_text = output.text_diff
173172

174-
in_reasoning, delta_text, reasoning_delta_text = apply_reasoning_parser(
173+
delta_text, reasoning_delta_text = apply_reasoning_parser(
175174
args, i, delta_text, True)
176175

177176
if args.tool_choice and type(
@@ -181,12 +180,8 @@ def yield_first_chat(num_tokens: int,
181180
name=args.tool_choice.function.name, arguments=delta_text))
182181
])
183182
else:
184-
if in_reasoning:
185-
delta_message = DeltaMessage(
186-
reasoning_content=reasoning_delta_text)
187-
else:
188-
delta_message = DeltaMessage(
189-
content=delta_text, reasoning_content=reasoning_delta_text)
183+
delta_message = DeltaMessage(content=delta_text,
184+
reasoning_content=reasoning_delta_text)
190185

191186
choice = ChatCompletionResponseStreamChoice(
192187
index=i,
@@ -239,8 +234,8 @@ def chat_response_post_processor(
239234
choices: List[ChatCompletionResponseChoice] = []
240235
role = args.role
241236
for output in rsp.outputs:
242-
_, text, reasoning_text = apply_reasoning_parser(
243-
args, output.index, output.text, False)
237+
text, reasoning_text = apply_reasoning_parser(args, output.index,
238+
output.text, False)
244239

245240
if args.tool_choice and isinstance(args.tool_choice,
246241
ChatCompletionNamedToolChoiceParam):
@@ -252,8 +247,6 @@ def chat_response_post_processor(
252247
name=args.tool_choice.function.name, arguments=text))
253248
])
254249
else:
255-
if text is None:
256-
text = ""
257250
message = ChatMessage(role=role,
258251
content=text,
259252
reasoning_content=reasoning_text)

tests/unittest/llmapi/apps/_test_openai_reasoning.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,13 @@
99
pytestmark = pytest.mark.threadleak(enabled=False)
1010

1111

12-
@pytest.fixture(scope="module", ids=["DeepSeek-R1-Distill-Qwen-1.5B"])
13-
def model_name() -> str:
14-
return "DeepSeek-R1-Distill-Qwen-1.5B"
12+
# yapf: disable
13+
@pytest.fixture(scope="module",
14+
params=["DeepSeek-R1-Distill-Qwen-1.5B",
15+
"Qwen3/Qwen3-0.6B"])
16+
def model_name(request) -> str:
17+
return request.param
18+
# yapf: enable
1519

1620

1721
@pytest.fixture(scope="module", params=["trt", "pytorch"])
@@ -21,12 +25,19 @@ def backend(request):
2125

2226
@pytest.fixture(scope="module")
2327
def server(model_name: str, backend: str):
28+
# Skip specific model/backend combinations
29+
if model_name == "Qwen3/Qwen3-0.6B" and backend == "trt":
30+
pytest.skip("Qwen3 model not supported with trt backend")
31+
2432
model_path = get_model_path(model_name)
2533
args = ["--backend", f"{backend}"]
2634
max_beam_width = 1 if backend == "pytorch" else 2
2735
args.extend(["--max_beam_width", str(max_beam_width)])
2836
args.extend(["--max_batch_size", "2", "--max_seq_len", "1024"])
29-
args.extend(["--reasoning_parser", "deepseek-r1"])
37+
if model_name.startswith("Qwen3"):
38+
args.extend(["--reasoning_parser", "qwen3"])
39+
else:
40+
args.extend(["--reasoning_parser", "deepseek-r1"])
3041
with RemoteOpenAIServer(model_path, args) as remote_server:
3142
yield remote_server
3243

@@ -51,16 +62,10 @@ def test_reasoning_parser(client: openai.OpenAI, model_name: str, backend: str):
5162
extra_body=extra_body,
5263
)
5364

54-
if backend == "pytorch":
55-
assert len(resp.choices) == n
56-
for resp_choice in resp.choices:
57-
assert len(resp_choice.message.content) > 0
58-
assert len(resp_choice.message.reasoning_content) > 0
59-
else:
60-
assert len(resp.choices) == n
61-
for resp_choice in resp.choices:
62-
assert len(resp_choice.message.content) > 0
63-
assert len(resp_choice.message.reasoning_content) > 0
65+
assert len(resp.choices) == n
66+
for resp_choice in resp.choices:
67+
assert len(resp_choice.message.content) > 0
68+
assert len(resp_choice.message.reasoning_content) > 0
6469

6570

6671
@pytest.fixture(scope="module")
@@ -78,9 +83,9 @@ async def process_stream(
7883
delta = choice.delta.dict()
7984
content = delta.get("content", None)
8085
reasoning_content = delta.get("reasoning_content", None)
81-
if content is not None:
86+
if content:
8287
content_chunks.append(content)
83-
if reasoning_content is not None:
88+
if reasoning_content:
8489
reasoning_content_chunks.append(reasoning_content)
8590
return (content_chunks, reasoning_content_chunks)
8691

@@ -105,12 +110,17 @@ async def test_reasoning_parser_streaming(async_client: openai.AsyncOpenAI,
105110
stream = await async_client.chat.completions.create(
106111
model=model_name,
107112
messages=messages,
108-
max_completion_tokens=1,
113+
max_completion_tokens=2,
109114
temperature=0.0,
110115
stream=True,
111116
)
112117

113118
content_chunks, reasoning_content_chunks = await process_stream(
114119
stream=stream)
115120
assert len(content_chunks) == 0
116-
assert len(reasoning_content_chunks) == 1
121+
if model_name.startswith("Qwen3"):
122+
# First token would be <think>
123+
assert len(reasoning_content_chunks) == 1
124+
else:
125+
# <think> is in chat template
126+
assert len(reasoning_content_chunks) == 2

0 commit comments

Comments (0)