diff --git a/engines/python/setup/djl_python/lmi_vllm/request_response_utils.py b/engines/python/setup/djl_python/lmi_vllm/request_response_utils.py
index bb050cc6d..0d0168845 100644
--- a/engines/python/setup/djl_python/lmi_vllm/request_response_utils.py
+++ b/engines/python/setup/djl_python/lmi_vllm/request_response_utils.py
@@ -55,6 +55,7 @@ def convert_lmi_schema_to_completion_request(
     parameters = payload.get("parameters", {})
 
     completion_dict = {
+        "model": payload.pop("model"),
         "prompt": payload.pop("inputs"),
         "max_tokens": parameters.pop("max_new_tokens", 30),
         "echo": parameters.pop("return_full_text", False),
diff --git a/engines/python/setup/djl_python/lmi_vllm/vllm_async_service.py b/engines/python/setup/djl_python/lmi_vllm/vllm_async_service.py
index 792e462ab..f3d216561 100644
--- a/engines/python/setup/djl_python/lmi_vllm/vllm_async_service.py
+++ b/engines/python/setup/djl_python/lmi_vllm/vllm_async_service.py
@@ -165,6 +165,8 @@ def preprocess_request(self, inputs: Input) -> ProcessedRequest:
             logging.info(
                 f"Using LoRA request: {lora_request.lora_name} (ID: {lora_request.lora_int_id})"
             )
+            # Set the model field to the adapter name so vLLM's _maybe_get_adapters() can extract it
+            decoded_payload["model"] = adapter_name
 
         # completions request
         if "prompt" in decoded_payload:
@@ -238,22 +240,9 @@ async def inference(
                 "", error=f"Input parsing failed: {str(e)}", code=424)
             return output
 
-        if processed_request.lora_request:
-            original_add_request = self.vllm_engine.add_request
-
-            async def add_request_with_lora(*args, **kwargs):
-                kwargs['lora_request'] = processed_request.lora_request
-                return await original_add_request(*args, **kwargs)
-
-            self.vllm_engine.add_request = add_request_with_lora
-            try:
-                response = await processed_request.inference_invoker(
-                    processed_request.vllm_request)
-            finally:
-                self.vllm_engine.add_request = original_add_request
-        else:
-            response = await processed_request.inference_invoker(
-                processed_request.vllm_request)
+        # vLLM will extract the adapter from the request object via _maybe_get_adapters()
+        response = await processed_request.inference_invoker(
+            processed_request.vllm_request)
 
         if isinstance(response, types.AsyncGeneratorType):
            # Apply custom formatter to streaming response
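
For reviewers, a minimal illustrative sketch (not part of the patch) of the request flow this change assumes. The payload values and the adapter name "my-lora" are made-up examples; only the "model"/"prompt"/"max_tokens"/"echo" mapping mirrors convert_lmi_schema_to_completion_request() above.

# Hypothetical LMI-style payload; values and adapter name are illustrative only.
lmi_payload = {
    "inputs": "What is Deep Learning?",
    "parameters": {"max_new_tokens": 64, "return_full_text": False},
}
adapter_name = "my-lora"

# preprocess_request() now stamps the adapter name into the "model" field ...
lmi_payload["model"] = adapter_name

# ... so the completion dict handed to vLLM's OpenAI-compatible layer carries the
# adapter name itself, and the service no longer needs to monkey-patch add_request().
parameters = lmi_payload.get("parameters", {})
completion_dict = {
    "model": lmi_payload.pop("model"),
    "prompt": lmi_payload.pop("inputs"),
    "max_tokens": parameters.pop("max_new_tokens", 30),
    "echo": parameters.pop("return_full_text", False),
}
assert completion_dict["model"] == "my-lora"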