@@ -38,6 +38,7 @@ def parse_chat_completions_request_vllm(
 ):

     tool_parser = rolling_batch.get_tool_parser()
+    reasoning_parser = rolling_batch.get_reasoning_parser()
     model = input_map.pop("model", "lmi")
     chat_params = ChatCompletionRequest(**input_map, model=model)

@@ -90,6 +91,7 @@ def parse_chat_completions_request_vllm(
         "request_prompts": request_prompt,
         "engine_prompt": engine_prompt,
         "tool_parser": tool_parser,
+        "reasoning_parser": reasoning_parser,
         "chat_params": chat_params,
     }
     return input_text, params
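Note: nothing changes in the request schema here. Reasoning is enabled purely server-side, and the parser simply rides along in `params` for the output formatters to pick up. A minimal sketch of a payload that flows through this path (the message text is borrowed from the integration tests below; `model` falls back to `"lmi"` when omitted, per the `input_map.pop` above):

```python
# Hypothetical client payload; no reasoning-specific fields are needed.
request = {
    "messages": [{
        "role": "user",
        "content": "9.11 and 9.8, which is greater?"
    }],
    "max_tokens": 256,
}
```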
39 changes: 37 additions & 2 deletions engines/python/setup/djl_python/output_formatter.py
@@ -284,6 +284,7 @@ def _json_chat_output_formatter(request_output: TextGenerationOutput):
     parameters = request_output.input.parameters
     chat_params = parameters.get("chat_params")
     tool_parser = parameters.get("tool_parser")
+    reasoning_parser = parameters.get("reasoning_parser")
     best_sequence = request_output.sequences[
         request_output.best_sequence_index]
     generated_text = get_generated_text(best_sequence, request_output)
@@ -301,7 +302,24 @@ def _json_chat_output_formatter(request_output: TextGenerationOutput):
         "logprobs": None,
         "finish_reason": best_sequence.finish_reason,
     }
-    if chat_params and chat_params.tool_choice and type(
+
+    if reasoning_parser:
+        reasoning_content, content = (
+            reasoning_parser.extract_reasoning_content(generated_text,
+                                                       request=chat_params))
+
+        if reasoning_content:
+            choice = {
+                "index": 0,
+                "message": {
+                    "role": "assistant",
+                    "content": content,
+                },
+                "reasoning_content": reasoning_content,
+                "logprobs": None,
+                "finish_reason": best_sequence.finish_reason,
+            }
+    elif chat_params and chat_params.tool_choice and type(
             chat_params.tool_choice
     ).__name__ == "ChatCompletionNamedToolChoiceParam":
         tool_calls = [{
@@ -386,6 +404,7 @@ def _jsonlines_chat_output_formatter(request_output: TextGenerationOutput):
     parameters = request_output.input.parameters
     chat_params = parameters.get("chat_params")
     tool_parser = parameters.get("tool_parser")
+    reasoning_parser = parameters.get("reasoning_parser")
     best_sequence = request_output.sequences[
         request_output.best_sequence_index]
     next_token, index, first_token, last_token = best_sequence.get_next_token()
@@ -396,7 +415,23 @@ def _jsonlines_chat_output_formatter(request_output: TextGenerationOutput):

     created = int(time.time())

-    if chat_params and chat_params.tool_choice and type(
+    if reasoning_parser:
+        current_text = get_generated_text(best_sequence, request_output)
+        previous_text = current_text[0:-len(next_token.text)]
+        current_token_ids = [t.id for t in best_sequence.tokens]
+        previous_token_ids = current_token_ids[:-1]
+        delta = reasoning_parser.extract_reasoning_content_streaming(
+            previous_text=previous_text,
+            current_text=current_text,
+            delta_text=next_token.text,
+            previous_token_ids=previous_token_ids,
+            current_token_ids=current_token_ids,
+            delta_token_ids=[next_token.id],
+        )
+        if delta is None:
+            return ""
+        delta = delta.model_dump(exclude_unset=True)
+    elif chat_params and chat_params.tool_choice and type(
             chat_params.tool_choice
     ).__name__ == "ChatCompletionNamedToolChoiceParam":
         tool_calls = [{
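Taken together, the two branches above shape responses roughly as follows for a DeepSeek-R1-style generation (a sketch: the `<think>…</think>` split is the `deepseek_r1` parser's convention in recent vLLM releases, and the sample strings and finish reason are illustrative):

```python
# Non-streaming (_json_chat_output_formatter): reasoning_content sits on the
# choice, alongside "message", mirroring the choice dict built above.
choice = {
    "index": 0,
    "message": {
        "role": "assistant",
        "content": "9.8 is greater.",
    },
    "reasoning_content": "Compare fractional parts: 0.8 > 0.11, so 9.8 > 9.11.",
    "logprobs": None,
    "finish_reason": "eos_token",  # illustrative value
}

# Streaming (_jsonlines_chat_output_formatter): each token becomes a delta from
# extract_reasoning_content_streaming(); model_dump(exclude_unset=True) keeps
# only the field the parser set, and None deltas (e.g. the "<think>" marker
# itself) are swallowed by the `return ""` above.
delta_while_thinking = {"reasoning_content": "Compare"}
delta_after_thinking = {"content": "9.8"}
```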
@@ -74,6 +74,10 @@ class VllmRbProperties(Properties):
     enable_auto_tool_choice: bool = False
     tool_call_parser: Optional[str] = None

+    # Reasoning properties
+    enable_reasoning: bool = False
+    reasoning_parser: Optional[str] = None
+
     # Neuron vLLM properties
     device: str = 'auto'
     preloaded_model: Optional[Any] = None
@@ -129,6 +133,18 @@ def validate_tool_call_parser(self):
                     f"(choose from {{ {','.join(valid_tool_parses)} }})")
         return self

+    @model_validator(mode='after')
+    def validate_reasoning_parser(self):
+        if self.enable_reasoning:
+            from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
+            valid_reasoning_parses = ReasoningParserManager.reasoning_parsers.keys(
+            )
+            if self.reasoning_parser not in valid_reasoning_parses:
+                raise ValueError(
+                    f"Invalid reasoning parser: {self.reasoning_parser} "
+                    f"(choose from {{ {','.join(valid_reasoning_parses)} }})")
+        return self
+
     @field_validator('override_neuron_config', mode="before")
     def validate_override_neuron_config(cls, val):
         if isinstance(val, str):
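These fields surface to users as `option.enable_reasoning` and `option.reasoning_parser` (see the prepare.py entry below), and the validator fails fast at startup rather than at request time. A quick way to list the names it will accept, using the same vLLM registry the validator imports (a sketch; available names depend on the installed vLLM version):

```python
from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager

# The validator accepts exactly these keys for option.reasoning_parser.
print(sorted(ReasoningParserManager.reasoning_parsers.keys()))
# e.g. ['deepseek_r1', ...] depending on the vLLM version
```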
@@ -126,6 +126,12 @@ def get_tool_parser(self):
         """
         return None

+    def get_reasoning_parser(self):
+        """
+        :return: the reasoning parser if available
+        """
+        return None
+
     @abstractmethod
     def inference(self, new_requests: List[Request]) -> List:
         """
@@ -52,6 +52,7 @@ def __init__(self, model_id_or_path: str, properties: dict,
         self.lora_requests = {}
         self.is_mistral_tokenizer = args.tokenizer_mode == 'mistral'
         self.tool_parser = None
+        self.reasoning_parser = None
         if self.vllm_configs.enable_auto_tool_choice:
             from vllm.entrypoints.openai.tool_parsers import ToolParserManager
             try:
@@ -61,6 +62,15 @@
                     self.engine.tokenizer.tokenizer)
             except Exception as e:
                 raise TypeError("Error in tool parser creation.") from e
+        if self.vllm_configs.enable_reasoning:
+            from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
+            try:
+                self.reasoning_parser = ReasoningParserManager.get_reasoning_parser(
+                    self.vllm_configs.reasoning_parser)
+                self.reasoning_parser = self.reasoning_parser(
+                    self.engine.tokenizer.tokenizer)
+            except Exception as e:
+                raise TypeError("Error in reasoning parser creation.") from e

     def get_tokenizer(self):
         return self.engine.tokenizer.tokenizer
@@ -77,6 +87,9 @@ def use_vllm_chat_completions(self):
     def get_tool_parser(self):
         return self.tool_parser

+    def get_reasoning_parser(self):
+        return self.reasoning_parser
+
     def get_chat_template(self):
         if self.is_mistral_tokenizer:
             # Mistral tokenizer chat template cannot be overridden
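`ReasoningParserManager.get_reasoning_parser` returns a parser class, which explains the two-step assignment above: look up the class, then instantiate it with the engine tokenizer. A hedged sketch of the instantiated parser at work (`tokenizer` and `chat_request` are stand-ins for the engine tokenizer and the parsed request; the split shown is the `deepseek_r1` behavior for R1-style output):

```python
# Sketch, not DJL code: exercising the parser the rolling batch just built.
# `tokenizer` and `chat_request` are stand-ins, not defined here.
from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager

parser_cls = ReasoningParserManager.get_reasoning_parser("deepseek_r1")
parser = parser_cls(tokenizer)  # same two-step shape as __init__ above

text = "<think>0.8 > 0.11, so 9.8 > 9.11.</think>9.8 is greater."
reasoning, answer = parser.extract_reasoning_content(text, request=chat_request)
# reasoning == "0.8 > 0.11, so 9.8 > 9.11."
# answer    == "9.8 is greater."
```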
30 changes: 28 additions & 2 deletions tests/integration/llm/client.py
@@ -606,7 +606,12 @@ def get_model_name():
         "batch_size": [1, 4],
         "seq_length": [256],
         "tokenizer": "TheBloke/Llama-2-7B-Chat-fp16",
-    }
+    },
+    "deepseek-r1-distill-qwen-1-5b": {
+        "batch_size": [1, 4],
+        "seq_length": [256],
+        "tokenizer": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
+    },
 }

 vllm_tool_model_spec = {
@@ -1423,6 +1428,24 @@ def batch_generation_tool(batch_size):
     return data[:batch_size]


+def batch_generation_reasoning(batch_size):
+    messages = [
+        [{
+            "role": "user",
+            "content": "9.11 and 9.8, which is greater?"
+        }],
+        [{
+            "role": "user",
+            "content": "How many Rs are there in the word 'strawberry'?"
+        }],
+    ]
+
+    if batch_size > len(messages):
+        # dynamically extend to support larger bs by repetition
+        messages *= math.ceil(batch_size / len(messages))
+    return messages[:batch_size]
+
+
 def t5_batch_generation(batch_size):
     input_sentences = [
         "translate English to German: The house is wonderful.",
@@ -1667,7 +1690,10 @@ def test_handler_rolling_batch_chat(model, model_spec):
     check_worker_number(spec["worker"])
     stream_values = spec.get("stream", [False, True])
     # dryrun phase
-    req = {"messages": batch_generation_chat(1)[0]}
+    if spec.get("enable_reasoning", False):
+        req = {"messages": batch_generation_reasoning(1)[0]}
+    else:
+        req = {"messages": batch_generation_chat(1)[0]}
     seq_length = 100
     req["max_tokens"] = seq_length
     req["logprobs"] = True
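One reviewer-level note: this branch keys off `spec.get("enable_reasoning", False)`, but the `deepseek-r1-distill-qwen-1-5b` entry added to the model spec dict above does not set that key, so as written the dryrun would still use the plain chat prompts unless the spec also carries `"enable_reasoning": True`. For reference, the assembled reasoning dryrun request would look like this (values taken from the test code above; stream toggling and the HTTP call are elided):

```python
req = {
    "messages": [{
        "role": "user",
        "content": "9.11 and 9.8, which is greater?"
    }],
    "max_tokens": 100,
    "logprobs": True,
}
```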
7 changes: 7 additions & 0 deletions tests/integration/llm/prepare.py
@@ -1086,6 +1086,13 @@
         "option.enable_auto_tool_choice": True,
         "option.tool_call_parser": "mistral",
     },
+    "deepseek-r1-distill-qwen-1-5b": {
+        "option.model_id": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
+        "option.tensor_parallel_degree": 1,
+        "option.max_rolling_batch_size": 4,
+        "option.enable_reasoning": True,
+        "option.reasoning_parser": "deepseek_r1",
+    },
 }

 vllm_neo_model_list = {
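For reference, the test harness materializes these keys into the model's serving.properties; assuming each `option.*` entry is written verbatim (and booleans lowercased), the resulting config would read roughly as follows:

```python
# Assumed serving.properties content (sketch, not verified harness output):
expected = """\
option.model_id=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
option.tensor_parallel_degree=1
option.max_rolling_batch_size=4
option.enable_reasoning=true
option.reasoning_parser=deepseek_r1
"""
```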
6 changes: 6 additions & 0 deletions tests/integration/tests.py
@@ -643,6 +643,12 @@ def test_mistral_7b_instruct_v03_tool(self):
             r.launch()
             client.run("vllm_tool mistral-7b-instruct-v03-tool".split())

+    def test_deepseek_r1_distill_qwen_1_5b(self):
+        with Runner('lmi', 'deepseek-r1-distill-qwen-1-5b') as r:
+            prepare.build_vllm_model("deepseek-r1-distill-qwen-1-5b")
+            r.launch()
+            client.run("vllm_chat deepseek-r1-distill-qwen-1-5b".split())
+

 @pytest.mark.vllm
 @pytest.mark.lora
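This wires the new model into the existing `vllm_chat` client path (the same `test_handler_rolling_batch_chat` flow shown above). Locally the test can be selected with something like `pytest tests/integration/tests.py -k deepseek_r1_distill_qwen_1_5b`, though the exact invocation depends on the integration harness and available hardware.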