diff --git a/engines/python/setup/djl_python/chat_completions/vllm_chat_utils.py b/engines/python/setup/djl_python/chat_completions/vllm_chat_utils.py
index a25473663..8216b6656 100644
--- a/engines/python/setup/djl_python/chat_completions/vllm_chat_utils.py
+++ b/engines/python/setup/djl_python/chat_completions/vllm_chat_utils.py
@@ -38,6 +38,7 @@ def parse_chat_completions_request_vllm(
 ):
     tool_parser = rolling_batch.get_tool_parser()
+    reasoning_parser = rolling_batch.get_reasoning_parser()
     model = input_map.pop("model", "lmi")
     chat_params = ChatCompletionRequest(**input_map, model=model)
 
@@ -90,6 +91,7 @@ def parse_chat_completions_request_vllm(
         "request_prompts": request_prompt,
         "engine_prompt": engine_prompt,
         "tool_parser": tool_parser,
+        "reasoning_parser": reasoning_parser,
         "chat_params": chat_params,
     }
     return input_text, params
diff --git a/engines/python/setup/djl_python/output_formatter.py b/engines/python/setup/djl_python/output_formatter.py
index e47bbe812..c86709559 100644
--- a/engines/python/setup/djl_python/output_formatter.py
+++ b/engines/python/setup/djl_python/output_formatter.py
@@ -284,6 +284,7 @@ def _json_chat_output_formatter(request_output: TextGenerationOutput):
     parameters = request_output.input.parameters
     chat_params = parameters.get("chat_params")
     tool_parser = parameters.get("tool_parser")
+    reasoning_parser = parameters.get("reasoning_parser")
     best_sequence = request_output.sequences[
         request_output.best_sequence_index]
     generated_text = get_generated_text(best_sequence, request_output)
@@ -301,7 +302,24 @@ def _json_chat_output_formatter(request_output: TextGenerationOutput):
         "logprobs": None,
         "finish_reason": best_sequence.finish_reason,
     }
-    if chat_params and chat_params.tool_choice and type(
+
+    if reasoning_parser:
+        reasoning_content, content = (
+            reasoning_parser.extract_reasoning_content(generated_text,
+                                                       request=chat_params))
+
+        if reasoning_content:
+            choice = {
+                "index": 0,
+                "message": {
+                    "role": "assistant",
+                    "content": content,
+                },
+                "reasoning_content": reasoning_content,
+                "logprobs": None,
+                "finish_reason": best_sequence.finish_reason,
+            }
+    elif chat_params and chat_params.tool_choice and type(
             chat_params.tool_choice
     ).__name__ == "ChatCompletionNamedToolChoiceParam":
         tool_calls = [{
@@ -386,6 +404,7 @@ def _jsonlines_chat_output_formatter(request_output: TextGenerationOutput):
     parameters = request_output.input.parameters
     chat_params = parameters.get("chat_params")
     tool_parser = parameters.get("tool_parser")
+    reasoning_parser = parameters.get("reasoning_parser")
     best_sequence = request_output.sequences[
         request_output.best_sequence_index]
     next_token, index, first_token, last_token = best_sequence.get_next_token()
@@ -396,7 +415,23 @@ def _jsonlines_chat_output_formatter(request_output: TextGenerationOutput):
 
     created = int(time.time())
 
-    if chat_params and chat_params.tool_choice and type(
+    if reasoning_parser:
+        current_text = get_generated_text(best_sequence, request_output)
+        previous_text = current_text[0:-len(next_token.text)]
+        current_token_ids = [t.id for t in best_sequence.tokens]
+        previous_token_ids = current_token_ids[:-1]
+        delta = reasoning_parser.extract_reasoning_content_streaming(
+            previous_text=previous_text,
+            current_text=current_text,
+            delta_text=next_token.text,
+            previous_token_ids=previous_token_ids,
+            current_token_ids=current_token_ids,
+            delta_token_ids=[next_token.id],
+        )
+        if delta is None:
+            return ""
+        delta = delta.model_dump(exclude_unset=True)
+    elif chat_params and chat_params.tool_choice and type(
             chat_params.tool_choice
     ).__name__ == "ChatCompletionNamedToolChoiceParam":
         tool_calls = [{
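For reference, the non-streaming branch above relies on `extract_reasoning_content` returning a `(reasoning_content, content)` pair. The sketch below illustrates that contract for the `deepseek_r1` parser, which splits on DeepSeek-R1's `<think>...</think>` markers; `split_reasoning` is a hypothetical stand-in, not vLLM's implementation.

```python
# Hypothetical sketch of the (reasoning_content, content) split that the
# deepseek_r1 reasoning parser performs on a finished completion.
def split_reasoning(text: str,
                    start: str = "<think>",
                    end: str = "</think>"):
    if end not in text:
        # Generation stopped mid-reasoning: everything is reasoning content.
        return text.removeprefix(start), None
    reasoning, _, content = text.partition(end)
    return reasoning.removeprefix(start), content


generated_text = "<think>0.8 > 0.11, so 9.8 > 9.11.</think>9.8 is greater."
reasoning_content, content = split_reasoning(generated_text)
assert reasoning_content == "0.8 > 0.11, so 9.8 > 9.11."
assert content == "9.8 is greater."
```

The streaming branch applies the same split incrementally: `extract_reasoning_content_streaming` returns a delta whose fields route the new text to either `reasoning_content` or `content`, and a `None` delta (for example, when the new token is only the think tag itself) is mapped to an empty string so nothing is emitted for that token.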
diff --git a/engines/python/setup/djl_python/properties_manager/vllm_rb_properties.py b/engines/python/setup/djl_python/properties_manager/vllm_rb_properties.py
index 9488fcd06..bd0206fe8 100644
--- a/engines/python/setup/djl_python/properties_manager/vllm_rb_properties.py
+++ b/engines/python/setup/djl_python/properties_manager/vllm_rb_properties.py
@@ -74,6 +74,10 @@ class VllmRbProperties(Properties):
     enable_auto_tool_choice: bool = False
     tool_call_parser: Optional[str] = None
 
+    # Reasoning properties
+    enable_reasoning: bool = False
+    reasoning_parser: Optional[str] = None
+
     # Neuron vLLM properties
     device: str = 'auto'
     preloaded_model: Optional[Any] = None
@@ -129,6 +133,18 @@ def validate_tool_call_parser(self):
                 f"(chose from {{ {','.join(valid_tool_parses)} }})")
         return self
 
+    @model_validator(mode='after')
+    def validate_reasoning_parser(self):
+        if self.enable_reasoning:
+            from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
+            valid_reasoning_parses = ReasoningParserManager.reasoning_parsers.keys(
+            )
+            if self.reasoning_parser not in valid_reasoning_parses:
+                raise ValueError(
+                    f"Invalid reasoning parser: {self.reasoning_parser} "
+                    f"(chose from {{ {','.join(valid_reasoning_parses)} }})")
+        return self
+
     @field_validator('override_neuron_config', mode="before")
     def validate_override_neuron_config(cls, val):
         if isinstance(val, str):
diff --git a/engines/python/setup/djl_python/rolling_batch/rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/rolling_batch.py
index 764fac394..ee90ad15d 100644
--- a/engines/python/setup/djl_python/rolling_batch/rolling_batch.py
+++ b/engines/python/setup/djl_python/rolling_batch/rolling_batch.py
@@ -126,6 +126,12 @@ def get_tool_parser(self):
         """
         return None
 
+    def get_reasoning_parser(self):
+        """
+        :return: the reasoning parser if available
+        """
+        return None
+
     @abstractmethod
     def inference(self, new_requests: List[Request]) -> List:
         """
diff --git a/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py
index 9cd9d567f..76e5c8e0c 100644
--- a/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py
+++ b/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py
@@ -52,6 +52,7 @@ def __init__(self, model_id_or_path: str, properties: dict,
         self.lora_requests = {}
         self.is_mistral_tokenizer = args.tokenizer_mode == 'mistral'
         self.tool_parser = None
+        self.reasoning_parser = None
         if self.vllm_configs.enable_auto_tool_choice:
             from vllm.entrypoints.openai.tool_parsers import ToolParserManager
             try:
@@ -61,6 +62,15 @@ def __init__(self, model_id_or_path: str, properties: dict,
                     self.engine.tokenizer.tokenizer)
             except Exception as e:
                 raise TypeError("Error in tool parser creation.") from e
+        if self.vllm_configs.enable_reasoning:
+            from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
+            try:
+                self.reasoning_parser = ReasoningParserManager.get_reasoning_parser(
+                    self.vllm_configs.reasoning_parser)
+                self.reasoning_parser = self.reasoning_parser(
+                    self.engine.tokenizer.tokenizer)
+            except Exception as e:
+                raise TypeError("Error in reasoning parser creation.") from e
 
     def get_tokenizer(self):
         return self.engine.tokenizer.tokenizer
@@ -77,6 +87,9 @@ def use_vllm_chat_completions(self):
     def get_tool_parser(self):
         return self.tool_parser
 
+    def get_reasoning_parser(self):
+        return self.reasoning_parser
+
     def get_chat_template(self):
         if self.is_mistral_tokenizer:
             # Mistral tokenizer chat template cannot be overridden
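Taken together, the new properties and the constructor wiring mean the feature is driven entirely by two options. A serving.properties sketch (the `option.*` keys mirror the ones exercised in prepare.py below; the remaining lines are illustrative boilerplate):

```
engine=Python
option.rolling_batch=vllm
option.model_id=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
option.enable_reasoning=true
option.reasoning_parser=deepseek_r1
```

If `option.reasoning_parser` names a parser that is not registered with vLLM's `ReasoningParserManager`, the new `validate_reasoning_parser` validator raises a `ValueError` at startup instead of failing on the first request.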
diff --git a/tests/integration/llm/client.py b/tests/integration/llm/client.py
index 3aa3da0d3..d1a6f7eb2 100644
--- a/tests/integration/llm/client.py
+++ b/tests/integration/llm/client.py
@@ -606,7 +606,13 @@ def get_model_name():
         "batch_size": [1, 4],
         "seq_length": [256],
         "tokenizer": "TheBloke/Llama-2-7B-Chat-fp16",
-    }
+    },
+    "deepseek-r1-distill-qwen-1-5b": {
+        "batch_size": [1, 4],
+        "seq_length": [256],
+        "enable_reasoning": True,
+        "tokenizer": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
+    },
 }
 
 vllm_tool_model_spec = {
@@ -1423,6 +1428,24 @@ def batch_generation_tool(batch_size):
     return data[:batch_size]
 
 
+def batch_generation_reasoning(batch_size):
+    messages = [
+        [{
+            "role": "user",
+            "content": "9.11 and 9.8, which is greater?"
+        }],
+        [{
+            "role": "user",
+            "content": "How many Rs are there in the word 'strawberry'?"
+        }],
+    ]
+
+    if batch_size > len(messages):
+        # dynamically extend to support larger bs by repetition
+        messages *= math.ceil(batch_size / len(messages))
+    return messages[:batch_size]
+
+
 def t5_batch_generation(batch_size):
     input_sentences = [
         "translate English to German: The house is wonderful.",
@@ -1667,7 +1690,10 @@ def test_handler_rolling_batch_chat(model, model_spec):
     check_worker_number(spec["worker"])
     stream_values = spec.get("stream", [False, True])
     # dryrun phase
-    req = {"messages": batch_generation_chat(1)[0]}
+    if spec.get("enable_reasoning", False):
+        req = {"messages": batch_generation_reasoning(1)[0]}
+    else:
+        req = {"messages": batch_generation_chat(1)[0]}
     seq_length = 100
     req["max_tokens"] = seq_length
     req["logprobs"] = True
diff --git a/tests/integration/llm/prepare.py b/tests/integration/llm/prepare.py
index a820a96c9..49a654a0b 100644
--- a/tests/integration/llm/prepare.py
+++ b/tests/integration/llm/prepare.py
@@ -1086,6 +1086,13 @@
         "option.enable_auto_tool_choice": True,
         "option.tool_call_parser": "mistral",
     },
+    "deepseek-r1-distill-qwen-1-5b": {
+        "option.model_id": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
+        "option.tensor_parallel_degree": 1,
+        "option.max_rolling_batch_size": 4,
+        "option.enable_reasoning": True,
+        "option.reasoning_parser": "deepseek_r1",
+    },
 }
 
 vllm_neo_model_list = {
diff --git a/tests/integration/tests.py b/tests/integration/tests.py
index 1269b6132..b4717767a 100644
--- a/tests/integration/tests.py
+++ b/tests/integration/tests.py
@@ -643,6 +643,12 @@ def test_mistral_7b_instruct_v03_tool(self):
         r.launch()
         client.run("vllm_tool mistral-7b-instruct-v03-tool".split())
 
+    def test_deepseek_r1_distill_qwen_1_5b(self):
+        with Runner('lmi', 'deepseek-r1-distill-qwen-1-5b') as r:
+            prepare.build_vllm_model("deepseek-r1-distill-qwen-1-5b")
+            r.launch()
+            client.run("vllm_chat deepseek-r1-distill-qwen-1-5b".split())
+
     @pytest.mark.vllm
     @pytest.mark.lora
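Once the test model is deployed, a non-streaming chat request surfaces the split end to end. A sketch, assuming the standard LMI invocations endpoint (the URL and port are illustrative); the response shape follows the choice dict built in `_json_chat_output_formatter`, where `reasoning_content` sits at the choice level next to `message`:

```python
import requests

# Hypothetical local endpoint; adjust to the deployed container's address.
resp = requests.post(
    "http://localhost:8080/invocations",
    json={
        "messages": [{
            "role": "user",
            "content": "9.11 and 9.8, which is greater?"
        }],
        "max_tokens": 256,
    },
)
choice = resp.json()["choices"][0]
print(choice.get("reasoning_content"))  # chain of thought from <think>...</think>
print(choice["message"]["content"])     # final answer only
```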