11from abc import ABC , abstractmethod
22from dataclasses import dataclass
3- from typing import Dict , Optional
3+ from typing import Type
44
55
@dataclass
class ReasoningParserResult:
    """Output of a reasoning parser for one piece of model text.

    Both fields default to the empty string, so a result with "nothing to
    emit yet" is simply ``ReasoningParserResult()``.
    """

    # Text that belongs to the final answer (outside the reasoning section).
    content: str = ""
    # Text that belongs to the model's reasoning section.
    reasoning_content: str = ""
1610
1711
1812class BaseReasoningParser (ABC ):
@@ -34,62 +28,99 @@ class DeepSeekR1Parser(BaseReasoningParser):
3428 treat all the text before the </think> tag as `reasoning_content` and the text after as `content`.
3529 """
3630
37- def __init__ (self ):
31+ def __init__ (self , reasoning_at_start : bool = False ) -> None :
32+ self .reasoning_start = "<think>"
3833 self .reasoning_end = "</think>"
39- self .in_reasoning = True
34+ self .reasoning_at_start = reasoning_at_start
35+ self .in_reasoning = self .reasoning_at_start
36+ self ._buffer = ""
4037
4138 def _create_reasoning_end_result (self , content : str ,
4239 reasoning_content : str ):
4340 if len (content ) == 0 :
4441 reasoning_parser_result = ReasoningParserResult (
45- True , reasoning_content = reasoning_content )
42+ reasoning_content = reasoning_content )
4643 elif len (reasoning_content ) == 0 :
47- reasoning_parser_result = ReasoningParserResult (False ,
48- content = content )
44+ reasoning_parser_result = ReasoningParserResult (content = content )
4945 else :
5046 reasoning_parser_result = ReasoningParserResult (
51- False , content = content , reasoning_content = reasoning_content )
47+ content = content , reasoning_content = reasoning_content )
5248 return reasoning_parser_result
5349
5450 def parse (self , text : str ) -> ReasoningParserResult :
55- if self .reasoning_end not in text :
56- return ReasoningParserResult (True , reasoning_content = text )
57-
58- splits = text .split (self .reasoning_end , maxsplit = 1 )
59- reasoning_content = splits [0 ]
60- content = splits [1 ]
61-
62- reasoning_parser_result = self ._create_reasoning_end_result (
63- content , reasoning_content )
64- return reasoning_parser_result
51+ if not self .reasoning_at_start :
52+ splits = text .partition (self .reasoning_start )
53+ if splits [1 ] == "" :
54+ # no reasoning start tag found
55+ return ReasoningParserResult (content = text )
56+ # reasoning start tag found
57+ # text before reasoning start tag is dropped
58+ text = splits [2 ]
59+ splits = text .partition (self .reasoning_end )
60+ reasoning_content , content = splits [0 ], splits [2 ]
61+ return ReasoningParserResult (content = content ,
62+ reasoning_content = reasoning_content )
6563
6664 def parse_delta (self , delta_text : str ) -> ReasoningParserResult :
67- if self .in_reasoning and self .reasoning_end in delta_text :
65+ self ._buffer += delta_text
66+ delta_text = self ._buffer
67+ reasoning_content = None
68+ content = None
69+ if (self .reasoning_start .startswith (delta_text )
70+ or self .reasoning_end .startswith (delta_text )):
71+ # waiting for more text to determine if it's a reasoning start or end tag
72+ return ReasoningParserResult ()
73+
74+ if not self .in_reasoning :
75+ begin_idx = delta_text .find (self .reasoning_start )
76+ if begin_idx == - 1 :
77+ self ._buffer = ""
78+ return ReasoningParserResult (content = delta_text )
79+ self .in_reasoning = True
80+ # set reasoning_content, will be processed by the next block
81+ reasoning_content = delta_text [begin_idx +
82+ len (self .reasoning_start ):]
83+
84+ if self .in_reasoning :
85+ delta_text = reasoning_content if reasoning_content is not None else delta_text
6886 end_idx = delta_text .find (self .reasoning_end )
87+ if end_idx == - 1 :
88+ last_idx = delta_text .rfind (self .reasoning_end [0 ])
89+ if last_idx != - 1 and self .reasoning_end .startswith (
90+ delta_text [last_idx :]):
91+ self ._buffer = delta_text [last_idx :]
92+ reasoning_content = delta_text [:last_idx ]
93+ else :
94+ self ._buffer = ""
95+ reasoning_content = delta_text
96+ return ReasoningParserResult (
97+ reasoning_content = reasoning_content )
6998 reasoning_content = delta_text [:end_idx ]
7099 content = delta_text [end_idx + len (self .reasoning_end ):]
71- reasoning_parser_result = self ._create_reasoning_end_result (
72- content , reasoning_content )
73100 self .in_reasoning = False
74- return reasoning_parser_result
75-
76- if self .in_reasoning :
77- return ReasoningParserResult (self .in_reasoning ,
78- reasoning_content = delta_text )
79-
80- # not self.in_reasoning:
81- return ReasoningParserResult (self .in_reasoning , content = delta_text )
101+ self ._buffer = ""
102+ return ReasoningParserResult (content = content ,
103+ reasoning_content = reasoning_content )
104+ raise RuntimeError (
105+ "Unreachable code reached in `DeepSeekR1Parser.parse_delta`" )
82106
83107
class ReasoningParserFactory:
    """Creates reasoning parsers from their registered (lower-case) names."""

    # Registered parser name -> parser class.
    parsers: dict[str, Type[BaseReasoningParser]] = {
        "deepseek-r1": DeepSeekR1Parser,
        "qwen3": DeepSeekR1Parser,
    }

    @staticmethod
    def create_reasoning_parser(reasoning_parser: str) -> BaseReasoningParser:
        """Instantiate the parser registered under `reasoning_parser`.

        The lookup is case-insensitive.

        Args:
            reasoning_parser: Name of the parser, e.g. ``"deepseek-r1"``.

        Returns:
            A new parser instance. The ``"deepseek-r1"`` parser is created
            with ``reasoning_at_start=True``.

        Raises:
            ValueError: If no parser is registered under that name.
        """
        # Normalise once so the lookup and the special case below agree.
        parser_name = reasoning_parser.lower()
        # Keep the try body to the lookup only, so a KeyError raised while
        # constructing the parser is not reported as an invalid name.
        try:
            parser_cls = ReasoningParserFactory.parsers[parser_name]
        except KeyError as e:
            raise ValueError(
                f"Invalid reasoning parser: {reasoning_parser}\n"
                f"Supported parsers: {list(ReasoningParserFactory.parsers.keys())}"
            ) from e

        if parser_name == "deepseek-r1":
            return parser_cls(reasoning_at_start=True)
        return parser_cls()
0 commit comments