
Commit 234ef92

add model status in vl
1 parent 9082f62 commit 234ef92

File tree

3 files changed (+46 additions, -9 deletions)


fastdeploy/input/ernie4_5_processor.py

Lines changed: 12 additions & 3 deletions
```diff
@@ -232,7 +232,8 @@ def process_request_dict(self, request, max_model_len=None):
             request["top_p"] = _SAMPLING_EPS
         if self.reasoning_parser and self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser":
             request["enable_thinking"] = True
-
+        if self.reasoning_parser:
+            request["model_status"] = self.reasoning_parser.get_model_status(request["prompt_token_ids"])
         data_processor_logger.info(f"Processed request dict: {request}")
         return request
 
@@ -246,6 +247,7 @@ def process_response(self, response_dict, **kwargs):
         Returns:
             Dict: response contain text fields
         """
+        model_status = kwargs.get("model_status")
         req_id = response_dict.request_id
         token_ids = response_dict.outputs.token_ids
 
@@ -254,7 +256,9 @@ def process_response(self, response_dict, **kwargs):
             token_ids = token_ids[:-1]
         full_text = self.tokenizer.decode(token_ids)
         if self.reasoning_parser:
-            reasoning_content, text = self.reasoning_parser.extract_reasoning_content(full_text, response_dict)
+            reasoning_content, text = self.reasoning_parser.extract_reasoning_content(
+                full_text, response_dict, model_status
+            )
             response_dict.outputs.text = text
             response_dict.outputs.reasoning_content = reasoning_content
         else:
@@ -296,6 +300,7 @@ def process_response_dict_normal(self, response_dict, **kwargs):
             Dict: response contain text fields
         """
         enable_thinking = kwargs.get("enable_thinking")
+        model_status = kwargs.get("model_status")
         token_ids = response_dict["outputs"]["token_ids"]
         is_end = response_dict["finished"]
         req_id = response_dict["request_id"]
@@ -308,7 +313,9 @@ def process_response_dict_normal(self, response_dict, **kwargs):
         if self.reasoning_parser and (
             enable_thinking or self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser"
         ):
-            reasoning_content, text = self.reasoning_parser.extract_reasoning_content(full_text, response_dict)
+            reasoning_content, text = self.reasoning_parser.extract_reasoning_content(
+                full_text, response_dict, model_status
+            )
             response_dict["outputs"]["text"] = text
             response_dict["outputs"]["reasoning_content"] = reasoning_content
         else:
@@ -335,6 +342,7 @@ def process_response_dict_streaming(self, response_dict, **kwargs):
             Dict: response contain text fields
         """
         enable_thinking = kwargs.get("enable_thinking")
+        model_status = kwargs.get("model_status")
         is_end = response_dict["finished"]
         req_id = response_dict["request_id"]
         token_ids = response_dict["outputs"]["token_ids"]
@@ -354,6 +362,7 @@ def process_response_dict_streaming(self, response_dict, **kwargs):
                 previous_token_ids,
                 previous_token_ids + token_ids,
                 token_ids,
+                model_status,
             )
             response_dict["outputs"]["delta_message"] = reasoning_delta_message
         if self.tool_parser_obj:
```
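The change threads one new field through the processor: the request side computes `model_status` once from the prompt token ids, and every response path reads it back out of `kwargs` and forwards it to the reasoning parser. A minimal, self-contained sketch of that flow; `StubParser` and token id 100 are stand-ins for illustration, not FastDeploy code:

```python
class StubParser:
    def get_model_status(self, prompt_token_ids):
        # 100 stands in for the <think> token id.
        return "thinking" if prompt_token_ids and prompt_token_ids[-1] == 100 else "responding"

    def extract_reasoning_content(self, full_text, response_dict, model_status):
        if model_status == "thinking":
            reasoning, _, text = full_text.partition("</think>")
            return reasoning, text
        return "", full_text

parser = StubParser()

# Request side (mirrors process_request_dict): stamp the status onto the request.
request = {"prompt_token_ids": [5, 7, 100]}
request["model_status"] = parser.get_model_status(request["prompt_token_ids"])

# Response side (mirrors process_response_dict_*): read it back from kwargs.
kwargs = {"model_status": request["model_status"]}
reasoning, text = parser.extract_reasoning_content(
    "2+2 is 4</think>The answer is 4.", {}, kwargs.get("model_status")
)
print(reasoning, "|", text)  # 2+2 is 4 | The answer is 4.
```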

fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -255,6 +255,9 @@ def process_request_dict(self, request, max_model_len=None):
             request["max_tokens"] = max(1, max_model_len - len(request["prompt_token_ids"]))
         data_processor_logger.info(f"Processed request {request}")
 
+        if self.reasoning_parser is not None:
+            request["model_status"] = self.reasoning_parser.get_model_status(request["prompt_token_ids"])
+
         return request
 
     def append_completion_tokens(self, multimodal_inputs, completion_token_ids):
```

fastdeploy/reasoning/ernie_vl_reasoning_parsers.py

Lines changed: 31 additions & 6 deletions
```diff
@@ -35,6 +35,7 @@ class ErnieVLReasoningParser(ReasoningParser):
 
     def __init__(self, tokenizer):
         super().__init__(tokenizer)
+        self.think_start_token = "<think>"
         self.think_end_token = "</think>"
 
         if not self.model_tokenizer:
@@ -45,10 +46,28 @@ def __init__(self, tokenizer):
         self.think_end_token_id = self.vocab.get(self.think_end_token)
         if self.think_end_token_id is None:
             raise RuntimeError("Ernie VL reasoning parser could not locate think end tokens in the tokenizer!")
+        self.think_start_token_id = self.vocab.get(self.think_start_token)
 
     def is_reasoning_end(self, input_ids: list[int]) -> bool:
         return self.think_end_token_id in input_ids
 
+    def find_last_special_token(self, prompt_token_ids: list[int]) -> int:
+        for i in range(len(prompt_token_ids) - 1, -1, -1):
+            if prompt_token_ids[i] in [self.think_end_token_id, self.think_start_token_id]:
+                return prompt_token_ids[i]
+        return -1
+
+    def get_model_status(self, prompt_token_ids: list[int]):
+        special_token_id = self.find_last_special_token(prompt_token_ids)
+        if special_token_id == -1:
+            return "responding"
+        if special_token_id == self.think_end_token_id:
+            return "responding"
+        if self.think_start_token_id == special_token_id:
+            return "thinking"
+
+        return "responding"
+
     def extract_reasoning_content_streaming(
         self,
         previous_text: str,
```
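`get_model_status` reduces to: whichever think marker appears last in the prompt decides whether the model is mid-thought. A standalone sketch of the same logic, with made-up token ids in place of the real tokenizer vocabulary:

```python
THINK_START_ID, THINK_END_ID = 100, 101  # made-up ids; the real ones come from self.vocab

def get_model_status(prompt_token_ids: list[int]) -> str:
    # Scan backwards for the most recent think marker (find_last_special_token).
    for tid in reversed(prompt_token_ids):
        if tid == THINK_START_ID:
            return "thinking"    # an opened, unclosed <think> block
        if tid == THINK_END_ID:
            return "responding"  # reasoning already closed by </think>
    return "responding"          # no markers at all

assert get_model_status([1, 2, THINK_START_ID, 3]) == "thinking"
assert get_model_status([1, THINK_START_ID, 2, THINK_END_ID]) == "responding"
assert get_model_status([1, 2, 3]) == "responding"
```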
```diff
@@ -57,6 +76,7 @@ def extract_reasoning_content_streaming(
         previous_token_ids: Sequence[int],
         current_token_ids: Sequence[int],
         delta_token_ids: Sequence[int],
+        model_status: str,
     ) -> Union[DeltaMessage, None]:
         """
         Extract reasoning content from a delta message.
@@ -80,7 +100,10 @@
             return DeltaMessage(reasoning_content=delta_text)
 
     def extract_reasoning_content(
-        self, model_output: str, request: ChatCompletionRequest
+        self,
+        model_output: str,
+        request: ChatCompletionRequest,
+        model_status: str,
     ) -> tuple[Optional[str], Optional[str]]:
         """
         Extract reasoning content from the model output.
```
```diff
@@ -94,9 +117,11 @@
         """
 
         # Check if the model output contains the </think> tokens.
-        if self.think_end_token not in model_output:
+        if model_status == "thinking":
+            if self.think_end_token not in model_output:
+                return model_output, ""
+            reasoning_content, _, content = model_output.partition(self.think_end_token)
+            final_content = content or ""
+            return reasoning_content, final_content
+        else:
             return "", model_output
-        reasoning_content, _, content = model_output.partition(self.think_end_token)
-
-        final_content = content or ""
-        return reasoning_content, final_content
```
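With `model_status` in hand, the non-streaming split becomes: in the "thinking" state, everything before `</think>` is reasoning and everything after is the answer (or, if the tag never appears, the whole output is reasoning); in the "responding" state, the whole output is answer text. A toy illustration using plain string logic, no FastDeploy imports:

```python
def split_reasoning(model_output: str, model_status: str) -> tuple[str, str]:
    if model_status == "thinking":
        if "</think>" not in model_output:
            return model_output, ""  # still mid-thought: no answer yet
        reasoning, _, answer = model_output.partition("</think>")
        return reasoning, answer or ""
    return "", model_output          # already responding: all answer text

print(split_reasoning("2+2 is 4</think>The answer is 4.", "thinking"))
# -> ('2+2 is 4', 'The answer is 4.')
print(split_reasoning("The answer is 4.", "responding"))
# -> ('', 'The answer is 4.')
```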
