Skip to content

Commit 40588c3

Browse files
authored
[python] Support reasoning content (#2722)
1 parent 4ad8ce3 commit 40588c3

File tree

8 files changed

+115
-4
lines changed

8 files changed

+115
-4
lines changed

engines/python/setup/djl_python/chat_completions/vllm_chat_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def parse_chat_completions_request_vllm(
3838
):
3939

4040
tool_parser = rolling_batch.get_tool_parser()
41+
reasoning_parser = rolling_batch.get_reasoning_parser()
4142
model = input_map.pop("model", "lmi")
4243
chat_params = ChatCompletionRequest(**input_map, model=model)
4344

@@ -90,6 +91,7 @@ def parse_chat_completions_request_vllm(
9091
"request_prompts": request_prompt,
9192
"engine_prompt": engine_prompt,
9293
"tool_parser": tool_parser,
94+
"reasoning_parser": reasoning_parser,
9395
"chat_params": chat_params,
9496
}
9597
return input_text, params

engines/python/setup/djl_python/output_formatter.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,7 @@ def _json_chat_output_formatter(request_output: TextGenerationOutput):
284284
parameters = request_output.input.parameters
285285
chat_params = parameters.get("chat_params")
286286
tool_parser = parameters.get("tool_parser")
287+
reasoning_parser = parameters.get("reasoning_parser")
287288
best_sequence = request_output.sequences[
288289
request_output.best_sequence_index]
289290
generated_text = get_generated_text(best_sequence, request_output)
@@ -301,7 +302,24 @@ def _json_chat_output_formatter(request_output: TextGenerationOutput):
301302
"logprobs": None,
302303
"finish_reason": best_sequence.finish_reason,
303304
}
304-
if chat_params and chat_params.tool_choice and type(
305+
306+
if reasoning_parser:
307+
reasoning_content, content = (
308+
reasoning_parser.extract_reasoning_content(generated_text,
309+
request=chat_params))
310+
311+
if reasoning_content:
312+
choice = {
313+
"index": 0,
314+
"message": {
315+
"role": "assistant",
316+
"content": content,
317+
},
318+
"reasoning_content": reasoning_content,
319+
"logprobs": None,
320+
"finish_reason": best_sequence.finish_reason,
321+
}
322+
elif chat_params and chat_params.tool_choice and type(
305323
chat_params.tool_choice
306324
).__name__ == "ChatCompletionNamedToolChoiceParam":
307325
tool_calls = [{
@@ -386,6 +404,7 @@ def _jsonlines_chat_output_formatter(request_output: TextGenerationOutput):
386404
parameters = request_output.input.parameters
387405
chat_params = parameters.get("chat_params")
388406
tool_parser = parameters.get("tool_parser")
407+
reasoning_parser = parameters.get("reasoning_parser")
389408
best_sequence = request_output.sequences[
390409
request_output.best_sequence_index]
391410
next_token, index, first_token, last_token = best_sequence.get_next_token()
@@ -396,7 +415,23 @@ def _jsonlines_chat_output_formatter(request_output: TextGenerationOutput):
396415

397416
created = int(time.time())
398417

399-
if chat_params and chat_params.tool_choice and type(
418+
if reasoning_parser:
419+
current_text = get_generated_text(best_sequence, request_output)
420+
previous_text = current_text[0:-len(next_token.text)]
421+
current_token_ids = [t.id for t in best_sequence.tokens]
422+
previous_token_ids = current_token_ids[:-1]
423+
delta = reasoning_parser.extract_reasoning_content_streaming(
424+
previous_text=previous_text,
425+
current_text=current_text,
426+
delta_text=next_token.text,
427+
previous_token_ids=previous_token_ids,
428+
current_token_ids=current_token_ids,
429+
delta_token_ids=[next_token.id],
430+
)
431+
if delta is None:
432+
return ""
433+
delta = delta.model_dump(exclude_unset=True)
434+
elif chat_params and chat_params.tool_choice and type(
400435
chat_params.tool_choice
401436
).__name__ == "ChatCompletionNamedToolChoiceParam":
402437
tool_calls = [{

engines/python/setup/djl_python/properties_manager/vllm_rb_properties.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,10 @@ class VllmRbProperties(Properties):
7474
enable_auto_tool_choice: bool = False
7575
tool_call_parser: Optional[str] = None
7676

77+
# Reasoning properties
78+
enable_reasoning: bool = False
79+
reasoning_parser: Optional[str] = None
80+
7781
# Neuron vLLM properties
7882
device: str = 'auto'
7983
preloaded_model: Optional[Any] = None
@@ -129,6 +133,18 @@ def validate_tool_call_parser(self):
129133
f"(chose from {{ {','.join(valid_tool_parses)} }})")
130134
return self
131135

136+
@model_validator(mode='after')
137+
def validate_reasoning_parser(self):
138+
if self.enable_reasoning:
139+
from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
140+
valid_reasoning_parses = ReasoningParserManager.reasoning_parsers.keys(
141+
)
142+
if self.reasoning_parser not in valid_reasoning_parses:
143+
raise ValueError(
144+
f"Invalid reasoning parser: {self.reasoning_parser} "
145+
f"(chose from {{ {','.join(valid_reasoning_parses)} }})")
146+
return self
147+
132148
@field_validator('override_neuron_config', mode="before")
133149
def validate_override_neuron_config(cls, val):
134150
if isinstance(val, str):

engines/python/setup/djl_python/rolling_batch/rolling_batch.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,12 @@ def get_tool_parser(self):
126126
"""
127127
return None
128128

129+
def get_reasoning_parser(self):
130+
"""
131+
:return: the reasoning parser if available
132+
"""
133+
return None
134+
129135
@abstractmethod
130136
def inference(self, new_requests: List[Request]) -> List:
131137
"""

engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def __init__(self, model_id_or_path: str, properties: dict,
5252
self.lora_requests = {}
5353
self.is_mistral_tokenizer = args.tokenizer_mode == 'mistral'
5454
self.tool_parser = None
55+
self.reasoning_parser = None
5556
if self.vllm_configs.enable_auto_tool_choice:
5657
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
5758
try:
@@ -61,6 +62,15 @@ def __init__(self, model_id_or_path: str, properties: dict,
6162
self.engine.tokenizer.tokenizer)
6263
except Exception as e:
6364
raise TypeError("Error in tool parser creation.") from e
65+
if self.vllm_configs.enable_reasoning:
66+
from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
67+
try:
68+
self.reasoning_parser = ReasoningParserManager.get_reasoning_parser(
69+
self.vllm_configs.reasoning_parser)
70+
self.reasoning_parser = self.reasoning_parser(
71+
self.engine.tokenizer.tokenizer)
72+
except Exception as e:
73+
raise TypeError("Error in reasoning parser creation.") from e
6474

6575
def get_tokenizer(self):
6676
return self.engine.tokenizer.tokenizer
@@ -77,6 +87,9 @@ def use_vllm_chat_completions(self):
7787
def get_tool_parser(self):
7888
return self.tool_parser
7989

90+
def get_reasoning_parser(self):
91+
return self.reasoning_parser
92+
8093
def get_chat_template(self):
8194
if self.is_mistral_tokenizer:
8295
# Mistral tokenizer chat template cannot be overridden

tests/integration/llm/client.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -606,7 +606,12 @@ def get_model_name():
606606
"batch_size": [1, 4],
607607
"seq_length": [256],
608608
"tokenizer": "TheBloke/Llama-2-7B-Chat-fp16",
609-
}
609+
},
610+
"deepseek-r1-distill-qwen-1-5b": {
611+
"batch_size": [1, 4],
612+
"seq_length": [256],
613+
"tokenizer": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
614+
},
610615
}
611616

612617
vllm_tool_model_spec = {
@@ -1423,6 +1428,24 @@ def batch_generation_tool(batch_size):
14231428
return data[:batch_size]
14241429

14251430

1431+
def batch_generation_reasoning(batch_size):
1432+
messages = [
1433+
[{
1434+
"role": "user",
1435+
"content": "9.11 and 9.8, which is greater?"
1436+
}],
1437+
[{
1438+
"role": "user",
1439+
"content": "How many Rs are there in the word 'strawberry'?"
1440+
}],
1441+
]
1442+
1443+
if batch_size > len(messages):
1444+
# dynamically extend to support larger bs by repetition
1445+
messages *= math.ceil(batch_size / len(messages))
1446+
return messages[:batch_size]
1447+
1448+
14261449
def t5_batch_generation(batch_size):
14271450
input_sentences = [
14281451
"translate English to German: The house is wonderful.",
@@ -1667,7 +1690,10 @@ def test_handler_rolling_batch_chat(model, model_spec):
16671690
check_worker_number(spec["worker"])
16681691
stream_values = spec.get("stream", [False, True])
16691692
# dryrun phase
1670-
req = {"messages": batch_generation_chat(1)[0]}
1693+
if spec.get("enable_reasoning", False):
1694+
req = {"messages": batch_generation_reasoning(1)[0]}
1695+
else:
1696+
req = {"messages": batch_generation_chat(1)[0]}
16711697
seq_length = 100
16721698
req["max_tokens"] = seq_length
16731699
req["logprobs"] = True

tests/integration/llm/prepare.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1086,6 +1086,13 @@
10861086
"option.enable_auto_tool_choice": True,
10871087
"option.tool_call_parser": "mistral",
10881088
},
1089+
"deepseek-r1-distill-qwen-1-5b": {
1090+
"option.model_id": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
1091+
"option.tensor_parallel_degree": 1,
1092+
"option.max_rolling_batch_size": 4,
1093+
"option.enable_reasoning": True,
1094+
"option.reasoning_parser": "deepseek_r1",
1095+
},
10891096
}
10901097

10911098
vllm_neo_model_list = {

tests/integration/tests.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,12 @@ def test_mistral_7b_instruct_v03_tool(self):
643643
r.launch()
644644
client.run("vllm_tool mistral-7b-instruct-v03-tool".split())
645645

646+
def test_deepseek_r1_distill_qwen_1_5b(self):
647+
with Runner('lmi', 'deepseek-r1-distill-qwen-1-5b') as r:
648+
prepare.build_vllm_model("deepseek-r1-distill-qwen-1-5b")
649+
r.launch()
650+
client.run("vllm_chat deepseek-r1-distill-qwen-1-5b".split())
651+
646652

647653
@pytest.mark.vllm
648654
@pytest.mark.lora

0 commit comments

Comments
 (0)