@@ -60,6 +60,7 @@ def __init__(self, model_name_or_path, reasoning_parser_obj=None, tool_parser_ob
60
60
self .decode_status = dict ()
61
61
self .tool_parser_dict = dict ()
62
62
self .thinking_parser_dict = dict ()
63
+ self .model_status_dict = dict ()
63
64
self ._load_tokenizer ()
64
65
data_processor_logger .info (
65
66
f"tokenizer information: bos_token is { self .tokenizer .bos_token } \
@@ -154,6 +155,12 @@ def process_request(self, request, max_model_len=None, **kwargs):
154
155
request .set ("top_p" , _SAMPLING_EPS )
155
156
if self .reasoning_parser and self .reasoning_parser .__class__ .__name__ == "ErnieX1ReasoningParser" :
156
157
request .enable_thinking = True
158
+ if self .reasoning_parser :
159
+ self .model_status_dict [request .request_id ] = self .reasoning_parser .get_model_status (
160
+ request .prompt_token_ids
161
+ )
162
+ if self .model_status_dict [request .request_id ] == "think_start" :
163
+ request .enable_thinking = True
157
164
158
165
data_processor_logger .info (f"Processed request: { request } " )
159
166
return request
@@ -233,8 +240,8 @@ def process_request_dict(self, request, max_model_len=None):
233
240
if self .reasoning_parser and self .reasoning_parser .__class__ .__name__ == "ErnieX1ReasoningParser" :
234
241
request ["enable_thinking" ] = True
235
242
if self .reasoning_parser :
236
- request [ "model_status " ] = self .reasoning_parser .get_model_status (request ["prompt_token_ids" ])
237
- if request [ "model_status " ] == "think_start" :
243
+ self .model_status_dict [request ["request_id" ]] = self .reasoning_parser .get_model_status (request ["prompt_token_ids" ])
244
+ if self .model_status_dict [request ["request_id" ]] == "think_start" :
238
245
request ["enable_thinking" ] = True
239
246
data_processor_logger .info (f"Processed request dict: { request } " )
240
247
return request
@@ -274,6 +281,8 @@ def process_response(self, response_dict, **kwargs):
274
281
data_processor_logger .info (f"req_id:{ req_id } , token_ids: { token_ids } " )
275
282
if response_dict .outputs .text == "" and response_dict .outputs .reasoning_content == "" :
276
283
return None
284
+ if req_id in self .model_status_dict :
285
+ del self .model_status_dict [req_id ]
277
286
return response_dict
278
287
279
288
def process_response_dict (self , response_dict , stream , ** kwargs ):
@@ -302,7 +311,6 @@ def process_response_dict_normal(self, response_dict, **kwargs):
302
311
Dict: response contain text fields
303
312
"""
304
313
enable_thinking = kwargs .get ("enable_thinking" )
305
- model_status = kwargs .get ("model_status" )
306
314
token_ids = response_dict ["outputs" ]["token_ids" ]
307
315
is_end = response_dict ["finished" ]
308
316
req_id = response_dict ["request_id" ]
@@ -317,7 +325,7 @@ def process_response_dict_normal(self, response_dict, **kwargs):
317
325
enable_thinking or self .reasoning_parser .__class__ .__name__ == "ErnieX1ReasoningParser"
318
326
):
319
327
reasoning_content , text = self .reasoning_parser .extract_reasoning_content (
320
- full_text , response_dict , model_status
328
+ full_text , response_dict , self . model_status_dict . get ( req_id )
321
329
)
322
330
response_dict ["outputs" ]["text" ] = text
323
331
response_dict ["outputs" ]["reasoning_content" ] = reasoning_content
@@ -330,6 +338,8 @@ def process_response_dict_normal(self, response_dict, **kwargs):
330
338
response_dict ["outputs" ]["raw_prediction" ] = full_text
331
339
data_processor_logger .info (f"req_id:{ req_id } , decode_status: { self .decode_status [req_id ]} " )
332
340
del self .decode_status [req_id ]
341
+ if req_id in self .model_status_dict :
342
+ del self .model_status_dict [req_id ]
333
343
return response_dict
334
344
335
345
def process_response_dict_streaming (self , response_dict , ** kwargs ):
@@ -343,7 +353,6 @@ def process_response_dict_streaming(self, response_dict, **kwargs):
343
353
Dict: response contain text fields
344
354
"""
345
355
enable_thinking = kwargs .get ("enable_thinking" )
346
- model_status = kwargs .get ("model_status" )
347
356
is_end = response_dict ["finished" ]
348
357
req_id = response_dict ["request_id" ]
349
358
token_ids = response_dict ["outputs" ]["token_ids" ]
@@ -363,7 +372,7 @@ def process_response_dict_streaming(self, response_dict, **kwargs):
363
372
previous_token_ids ,
364
373
previous_token_ids + token_ids ,
365
374
token_ids ,
366
- model_status ,
375
+ self . model_status_dict . get ( req_id ) ,
367
376
)
368
377
response_dict ["outputs" ]["delta_message" ] = reasoning_delta_message
369
378
if self .tool_parser_obj :
@@ -387,6 +396,8 @@ def process_response_dict_streaming(self, response_dict, **kwargs):
387
396
del self .decode_status [req_id ]
388
397
if req_id in self .tool_parser_dict :
389
398
del self .tool_parser_dict [req_id ]
399
+ if req_id in self .model_status_dict :
400
+ del self .model_status_dict [req_id ]
390
401
return response_dict
391
402
392
403
def messages2ids (self , request_or_messages ):
0 commit comments