@@ -175,6 +175,7 @@ def __init__(self, model_name_or_path, reasoning_parser_obj=None, tool_parser_ob
175
175
self .generation_config = None
176
176
177
177
self .decode_status = dict ()
178
+ self .model_status_dict = dict ()
178
179
self .tool_parser_dict = dict ()
179
180
self .tokenizer = self ._load_tokenizer ()
180
181
data_processor_logger .info (
@@ -266,8 +267,10 @@ def process_request(self, request, max_model_len=None, **kwargs):
266
267
if request .get ("top_p" ) < _SAMPLING_EPS :
267
268
request .set ("top_p" , _SAMPLING_EPS )
268
269
if self .reasoning_parser :
269
- request .model_status = self .reasoning_parser .get_model_status (request .prompt_token_ids )
270
- if request .model_status == "think_start" :
270
+ self .model_status_dict [request .request_id ] = self .reasoning_parser .get_model_status (
271
+ request .prompt_token_ids
272
+ )
273
+ if self .model_status_dict [request .request_id ] == "think_start" :
271
274
request .enable_thinking = True
272
275
273
276
data_processor_logger .info (f"Processed request: { request } " )
@@ -343,6 +346,12 @@ def process_request_dict(self, request, max_model_len=None, **kwargs):
343
346
request ["temperature" ] = 1
344
347
if request .get ("top_p" ) < _SAMPLING_EPS :
345
348
request ["top_p" ] = _SAMPLING_EPS
349
+ if self .reasoning_parser :
350
+ self .model_status_dict [request ["request_id" ]] = self .reasoning_parser .get_model_status (
351
+ request ["prompt_token_ids" ]
352
+ )
353
+ if self .model_status_dict [request ["request_id" ]] == "think_start" :
354
+ request ["enable_thinking" ] = True
346
355
347
356
data_processor_logger .info (f"Processed request dict: { request } " )
348
357
return request
@@ -366,21 +375,22 @@ def process_response(self, response_dict, **kwargs):
366
375
if token_ids [- 1 ] == self .tokenizer .eos_token_id :
367
376
token_ids = token_ids [:- 1 ]
368
377
full_text = self .tokenizer .decode (token_ids )
369
-
378
+ response_dict . outputs . text = full_text
370
379
# 模型支持思考,并且支持思考
371
380
if self .reasoning_parser :
372
- reasoning_content , text = self .reasoning_parser .extract_reasoning_content (full_text , response_dict )
381
+ reasoning_content , text = self .reasoning_parser .extract_reasoning_content (
382
+ full_text , response_dict , self .model_status_dict [req_id ]
383
+ )
373
384
response_dict .outputs .text = text
374
385
response_dict .outputs .reasoning_content = reasoning_content
375
- else :
376
- # 模型不支持思考,并且没单独设置enable_thinking为false
377
- response_dict .outputs .text = full_text
378
386
if self .tool_parser_obj :
379
387
tool_parser = self .tool_parser_obj (self .tokenizer )
380
388
tool_call_info = tool_parser .extract_tool_calls (full_text , response_dict )
381
389
if tool_call_info .tools_called :
382
390
response_dict .outputs .tool_calls = tool_call_info .tool_calls
383
391
response_dict .outputs .text = tool_call_info .content
392
+ if req_id in self .model_status_dict :
393
+ del self .model_status_dict [req_id ]
384
394
data_processor_logger .info (f"req_id:{ req_id } , token_ids: { token_ids } " )
385
395
386
396
return response_dict
@@ -395,7 +405,6 @@ def process_response_dict_normal(self, response_dict, **kwargs):
395
405
Returns:
396
406
Dict: response contain text fields
397
407
"""
398
- enable_thinking = kwargs .get ("enable_thinking" )
399
408
token_ids = response_dict ["outputs" ]["token_ids" ]
400
409
is_end = response_dict ["finished" ]
401
410
req_id = response_dict ["request_id" ]
@@ -406,12 +415,13 @@ def process_response_dict_normal(self, response_dict, **kwargs):
406
415
if is_end :
407
416
full_text = previous_texts + delta_text
408
417
response_dict ["outputs" ]["raw_prediction" ] = full_text
409
- if enable_thinking and self .reasoning_parser :
410
- reasoning_content , text = self .reasoning_parser .extract_reasoning_content (full_text , response_dict )
418
+ response_dict ["outputs" ]["text" ] = full_text
419
+ if self .reasoning_parser :
420
+ reasoning_content , text = self .reasoning_parser .extract_reasoning_content (
421
+ full_text , response_dict , self .model_status_dict [req_id ]
422
+ )
411
423
response_dict ["outputs" ]["text" ] = text
412
424
response_dict ["outputs" ]["reasoning_content" ] = reasoning_content
413
- else :
414
- response_dict ["outputs" ]["text" ] = full_text
415
425
if self .tool_parser_obj :
416
426
tool_parser = self .tool_parser_obj (self .tokenizer )
417
427
tool_call_info = tool_parser .extract_tool_calls (full_text , response_dict )
@@ -432,7 +442,6 @@ def process_response_dict_streaming(self, response_dict, **kwargs):
432
442
Returns:
433
443
Dict: response contain text fields
434
444
"""
435
- enable_thinking = kwargs .get ("enable_thinking" )
436
445
is_end = response_dict ["finished" ]
437
446
req_id = response_dict ["request_id" ]
438
447
token_ids = response_dict ["outputs" ]["token_ids" ]
@@ -442,16 +451,15 @@ def process_response_dict_streaming(self, response_dict, **kwargs):
442
451
token_ids = token_ids [:- 1 ]
443
452
delta_text , previous_token_ids , previous_texts = self .ids2tokens (token_ids , req_id )
444
453
response_dict ["outputs" ]["raw_prediction" ] = delta_text
445
- if self .reasoning_parser and (
446
- enable_thinking or self .reasoning_parser .__class__ .__name__ == "ErnieX1ReasoningParser"
447
- ):
454
+ if self .reasoning_parser :
448
455
reasoning_delta_message = self .reasoning_parser .extract_reasoning_content_streaming (
449
456
previous_texts ,
450
457
previous_texts + delta_text ,
451
458
delta_text ,
452
459
previous_token_ids ,
453
460
previous_token_ids + token_ids ,
454
461
token_ids ,
462
+ self .model_status_dict [req_id ],
455
463
)
456
464
response_dict ["outputs" ]["delta_message" ] = reasoning_delta_message
457
465
if self .tool_parser_obj :
@@ -475,6 +483,8 @@ def process_response_dict_streaming(self, response_dict, **kwargs):
475
483
del self .decode_status [req_id ]
476
484
if req_id in self .tool_parser_dict :
477
485
del self .tool_parser_dict [req_id ]
486
+ if req_id in self .model_status_dict :
487
+ del self .model_status_dict [req_id ]
478
488
return response_dict
479
489
480
490
def process_response_dict (self , response_dict , ** kwargs ):
@@ -487,16 +497,12 @@ def process_response_dict(self, response_dict, **kwargs):
487
497
Returns:
488
498
Dict: response contain text fields
489
499
"""
490
- enable_thinking = kwargs .pop ("enable_thinking" , True )
491
- if enable_thinking is None :
492
- enable_thinking = True
493
500
stream = kwargs .get ("stream" , True )
494
501
if stream :
495
- return self .process_response_dict_streaming (response_dict , enable_thinking = enable_thinking , ** kwargs )
502
+ return self .process_response_dict_streaming (response_dict , ** kwargs )
496
503
else :
497
504
return self .process_response_dict_normal (
498
505
response_dict = response_dict ,
499
- enable_thinking = enable_thinking ,
500
506
** kwargs ,
501
507
)
502
508
0 commit comments