Skip to content

Commit 3230fbe

Browse files
authored
[None][feat] Update reasoning parser for nano-v3 (#9944)
Signed-off-by: Wanli Jiang <[email protected]>
1 parent 9e7182b commit 3230fbe

File tree

3 files changed

+100
-4
lines changed

3 files changed

+100
-4
lines changed

tensorrt_llm/llmapi/reasoning_parser.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from abc import ABC, abstractmethod
22
from dataclasses import dataclass
3-
from typing import Type
3+
from typing import Any, Optional, Type
44

55

66
@dataclass
@@ -109,15 +109,28 @@ class ReasoningParserFactory:
109109
parsers: dict[str, Type[BaseReasoningParser]] = {
110110
"deepseek-r1": DeepSeekR1Parser,
111111
"qwen3": DeepSeekR1Parser,
112+
"nano-v3": DeepSeekR1Parser,
112113
}
113114

114115
@staticmethod
115-
def create_reasoning_parser(reasoning_parser: str) -> BaseReasoningParser:
116+
def create_reasoning_parser(
117+
reasoning_parser: str,
118+
chat_template_kwargs: Optional[dict[str, Any]] = None
119+
) -> BaseReasoningParser:
116120
try:
117121
reasoning_parser_class = ReasoningParserFactory.parsers[
118122
reasoning_parser.lower()]
119123
if reasoning_parser == "deepseek-r1":
120124
return reasoning_parser_class(reasoning_at_start=True)
125+
elif reasoning_parser == "nano-v3":
126+
# Note: When the model runs with reasoning (the default), `reasoning_at_start` should be True so that the beginning of the response is parsed into `reasoning_content`.
127+
# When the model runs without reasoning, `reasoning_at_start` should be False so that the response is parsed into the `content` field.
128+
is_reasoning_model = True
129+
if isinstance(chat_template_kwargs, dict):
130+
is_reasoning_model = chat_template_kwargs.get(
131+
"enable_thinking", True)
132+
return reasoning_parser_class(
133+
reasoning_at_start=is_reasoning_model)
121134
return reasoning_parser_class()
122135
except KeyError as e:
123136
raise ValueError(

tensorrt_llm/serve/postprocess_handlers.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass, field
2-
from typing import List, Literal, Optional, Tuple, Union
2+
from typing import Any, List, Literal, Optional, Tuple, Union
33

44
from .._utils import nvtx_range_debug
55
from ..executor import (DetokenizedGenerationResultBase, GenerationResult,
@@ -55,6 +55,7 @@ class ChatPostprocArgs(PostprocArgs):
5555
tool_parser_dict: dict[int, BaseToolParser] = field(default_factory=dict)
5656
has_tool_call: dict[int, bool] = field(default_factory=dict)
5757
tool_call_id_type: str = "random"
58+
chat_template_kwargs: Optional[dict[str, Any]] = None
5859

5960
@classmethod
6061
def from_request(cls, request: ChatCompletionRequest):
@@ -69,6 +70,7 @@ def from_request(cls, request: ChatCompletionRequest):
6970
stream_options=request.stream_options,
7071
return_logprobs=bool(request.logprobs),
7172
top_logprobs=bool(request.top_logprobs),
73+
chat_template_kwargs=request.chat_template_kwargs,
7274
)
7375

7476

@@ -108,9 +110,10 @@ def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str,
108110
reasoning_parser = None
109111
if args.reasoning_parser is not None:
110112
if output_index not in args.reasoning_parser_dict:
113+
chat_template_kwargs = getattr(args, "chat_template_kwargs", None)
111114
args.reasoning_parser_dict[
112115
output_index] = ReasoningParserFactory.create_reasoning_parser(
113-
args.reasoning_parser)
116+
args.reasoning_parser, chat_template_kwargs)
114117
reasoning_parser = args.reasoning_parser_dict[output_index]
115118

116119
if reasoning_parser is not None:
@@ -501,13 +504,15 @@ class ChatCompletionPostprocArgs(PostprocArgs):
501504
tool_choice: Optional[Union[Literal["none", "auto"],
502505
ChatCompletionNamedToolChoiceParam]]
503506
request_id: Optional[int] = None
507+
chat_template_kwargs: Optional[dict[str, Any]] = None
504508

505509
@classmethod
506510
def from_request(cls, request: ChatCompletionRequest):
507511
return cls(
508512
model=request.model,
509513
tools=request.tools,
510514
tool_choice=request.tool_choice,
515+
chat_template_kwargs=request.chat_template_kwargs,
511516
)
512517

513518

tests/unittest/llmapi/test_reasoning_parser.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,81 @@ def test_qwen3_reasoning_parser_stream(delta_texts: list, content: list,
7171
result = reasoning_parser.parse_delta(delta_text)
7272
assert result.content == content[i]
7373
assert result.reasoning_content == reasoning_context[i]
74+
75+
76+
@pytest.mark.parametrize(
77+
("text", "content", "reasoning_context", "chat_template_kwargs"),
78+
[
79+
("a b", "", "a b", None),
80+
(f"{R1_END} a b", " a b", "", None),
81+
(f"a {R1_END} b", " b", "a ", None),
82+
(f"a b {R1_END}", "", "a b ", None),
83+
(f"{R1_START} a {R1_END} b", " b", f"{R1_START} a ", None),
84+
# Thinking disabled: the whole text stays in `content` and `reasoning_content` is empty.
85+
("a b", "a b", "", {
86+
"enable_thinking": False
87+
}),
88+
(f"{R1_END} a b", f"{R1_END} a b", "", {
89+
"enable_thinking": False
90+
}),
91+
(f"a {R1_END} b", f"a {R1_END} b", "", {
92+
"enable_thinking": False
93+
}),
94+
(f"a b {R1_END}", f"a b {R1_END}", "", {
95+
"enable_thinking": False
96+
}),
97+
])
98+
def test_nano_v3_reasoning_parser(text: str, content: str,
99+
reasoning_context: str,
100+
chat_template_kwargs: dict):
101+
reasoning_parser = ReasoningParserFactory.create_reasoning_parser(
102+
"nano-v3", chat_template_kwargs)
103+
result = reasoning_parser.parse(text)
104+
print(f"text: {text}, result: {result}")
105+
assert result.content == content
106+
assert result.reasoning_content == reasoning_context
107+
108+
109+
@pytest.mark.parametrize(
110+
("delta_texts", "content", "reasoning_context", "chat_template_kwargs"),
111+
[
112+
(["a", "b"], ["", ""], ["a", "b"], None),
113+
([R1_END, "a", "b"], ["", "a", "b"], ["", "", ""], None),
114+
(["a", R1_END, "b"], ["", "", "b"], ["a", "", ""], None),
115+
(["a", "b", R1_END], ["", "", ""], ["a", "b", ""], None),
116+
(["a", f"l{R1_END}", "b"], ["", "", "b"], ["a", "l", ""], None),
117+
(["a", f"l{R1_END}r", "b"], ["", "r", "b"], ["a", "l", ""], None),
118+
(["a", f"{R1_END}r", "b"], ["", "r", "b"], ["a", "", ""], None),
119+
# Thinking disabled: every delta is emitted as `content` (possibly buffered) and `reasoning_content` stays empty.
120+
(["a", "b"], ["a", "b"], ["", ""], {
121+
"enable_thinking": False
122+
}),
123+
([R1_END, "a", "b"], ["", f"{R1_END}a", "b"], ["", "", ""], {
124+
"enable_thinking": False
125+
}),
126+
(["a", R1_END, "b"], ["a", "", f"{R1_END}b"], ["", "", ""], {
127+
"enable_thinking": False
128+
}),
129+
(["a", "b", R1_END], ["a", "b", ""], ["", "", ""], {
130+
"enable_thinking": False
131+
}),
132+
(["a", f"l{R1_END}", "b"], ["a", f"l{R1_END}", "b"], ["", "", ""], {
133+
"enable_thinking": False
134+
}),
135+
(["a", f"l{R1_END}r", "b"], ["a", f"l{R1_END}r", "b"], ["", "", ""], {
136+
"enable_thinking": False
137+
}),
138+
(["a", f"{R1_END}r", "b"], ["a", f"{R1_END}r", "b"], ["", "", ""], {
139+
"enable_thinking": False
140+
}),
141+
])
142+
def test_nano_v3_reasoning_parser_stream(delta_texts: list, content: list,
143+
reasoning_context: list,
144+
chat_template_kwargs: dict):
145+
reasoning_parser = ReasoningParserFactory.create_reasoning_parser(
146+
"nano-v3", chat_template_kwargs)
147+
for i, delta_text in enumerate(delta_texts):
148+
result = reasoning_parser.parse_delta(delta_text)
149+
print(f"delta_text: {delta_text}, result: {result}")
150+
assert result.content == content[i]
151+
assert result.reasoning_content == reasoning_context[i]

0 commit comments

Comments
 (0)