@@ -232,7 +232,8 @@ def process_request_dict(self, request, max_model_len=None):
             request["top_p"] = _SAMPLING_EPS
         if self.reasoning_parser and self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser":
             request["enable_thinking"] = True
-
+        if self.reasoning_parser:
+            request["model_status"] = self.reasoning_parser.get_model_status(request["prompt_token_ids"])
         data_processor_logger.info(f"Processed request dict: {request}")
         return request
 
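The hunk above records the reasoning parser's model status on the request at preprocessing time, keyed off the prompt token ids. The parser's get_model_status implementation is not part of this diff, so the sketch below is a hypothetical illustration only; the sentinel token and return values are assumptions, not the repository's code.

# Hypothetical sketch: get_model_status() is not shown in this diff.
class SketchReasoningParser:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        # Assumed thinking-start sentinel; the real ERNIE token may differ.
        self.think_start_id = tokenizer.convert_tokens_to_ids("<think>")

    def get_model_status(self, prompt_token_ids):
        # If the rendered prompt already opens a thinking block, the model begins
        # generation inside it; otherwise it begins in normal response mode.
        if self.think_start_id in prompt_token_ids:
            return "think_start"
        return "response_start"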
@@ -246,6 +247,7 @@ def process_response(self, response_dict, **kwargs):
         Returns:
             Dict: response contain text fields
         """
+        model_status = kwargs.get("model_status")
         req_id = response_dict.request_id
         token_ids = response_dict.outputs.token_ids
@@ -254,7 +256,9 @@ def process_response(self, response_dict, **kwargs):
             token_ids = token_ids[:-1]
         full_text = self.tokenizer.decode(token_ids)
         if self.reasoning_parser:
-            reasoning_content, text = self.reasoning_parser.extract_reasoning_content(full_text, response_dict)
+            reasoning_content, text = self.reasoning_parser.extract_reasoning_content(
+                full_text, response_dict, model_status
+            )
             response_dict.outputs.text = text
             response_dict.outputs.reasoning_content = reasoning_content
         else:
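For context, the model_status consumed here via kwargs.get("model_status") is the value stored on the request in the first hunk. A minimal usage sketch, assuming a caller that holds both the preprocessed request and the engine output (the call site itself is outside this diff, so the variable names are illustrative):

# Illustrative flow, not the repository's actual call site.
request = processor.process_request_dict(raw_request)         # sets request["model_status"]
result = processor.process_response(
    engine_output,
    model_status=request.get("model_status"),                 # read back via kwargs.get("model_status")
)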
@@ -296,6 +300,7 @@ def process_response_dict_normal(self, response_dict, **kwargs):
             Dict: response contain text fields
         """
         enable_thinking = kwargs.get("enable_thinking")
+        model_status = kwargs.get("model_status")
         token_ids = response_dict["outputs"]["token_ids"]
         is_end = response_dict["finished"]
         req_id = response_dict["request_id"]
@@ -308,7 +313,9 @@ def process_response_dict_normal(self, response_dict, **kwargs):
             if self.reasoning_parser and (
                 enable_thinking or self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser"
             ):
-                reasoning_content, text = self.reasoning_parser.extract_reasoning_content(full_text, response_dict)
+                reasoning_content, text = self.reasoning_parser.extract_reasoning_content(
+                    full_text, response_dict, model_status
+                )
                 response_dict["outputs"]["text"] = text
                 response_dict["outputs"]["reasoning_content"] = reasoning_content
             else:
@@ -335,6 +342,7 @@ def process_response_dict_streaming(self, response_dict, **kwargs):
             Dict: response contain text fields
         """
         enable_thinking = kwargs.get("enable_thinking")
+        model_status = kwargs.get("model_status")
         is_end = response_dict["finished"]
         req_id = response_dict["request_id"]
         token_ids = response_dict["outputs"]["token_ids"]
@@ -354,6 +362,7 @@ def process_response_dict_streaming(self, response_dict, **kwargs):
                 previous_token_ids,
                 previous_token_ids + token_ids,
                 token_ids,
+                model_status,
             )
             response_dict["outputs"]["delta_message"] = reasoning_delta_message
             if self.tool_parser_obj:
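The streaming path follows the same pattern: model_status arrives through kwargs and is appended as the final argument to extract_reasoning_content_streaming. A rough usage sketch, assuming the serving loop re-passes the request's stored status on every chunk (the loop and variable names are hypothetical; only the kwargs shown in the hunks above are grounded in this diff):

# Hypothetical serving loop around process_response_dict_streaming().
for chunk in engine_stream:
    out = processor.process_response_dict_streaming(
        chunk,
        enable_thinking=True,
        model_status=request.get("model_status"),
    )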