Commit bfe6b3e

Fix race condition in LoRA streaming requests (#2954) (#2957)
Co-authored-by: Loki <lokravi@amazon.com>
1 parent c3b8633 commit bfe6b3e

File tree

2 files changed (+6, -15 lines)


engines/python/setup/djl_python/lmi_vllm/request_response_utils.py

Lines changed: 1 addition & 0 deletions
@@ -55,6 +55,7 @@ def convert_lmi_schema_to_completion_request(
     parameters = payload.get("parameters", {})
 
     completion_dict = {
+        "model": payload.pop("model"),
         "prompt": payload.pop("inputs"),
         "max_tokens": parameters.pop("max_new_tokens", 30),
         "echo": parameters.pop("return_full_text", False),

engines/python/setup/djl_python/lmi_vllm/vllm_async_service.py

Lines changed: 5 additions & 15 deletions
@@ -168,6 +168,8 @@ def preprocess_request(self, inputs: Input) -> ProcessedRequest:
             logging.info(
                 f"Using LoRA request: {lora_request.lora_name} (ID: {lora_request.lora_int_id})"
             )
+            # Set the model field to the adapter name so vLLM's _maybe_get_adapters() can extract it
+            decoded_payload["model"] = adapter_name
 
         # completions request
         if "prompt" in decoded_payload:
@@ -241,21 +243,9 @@ async def inference(
                 "", error=f"Input parsing failed: {str(e)}", code=424)
             return output
 
-        if processed_request.lora_request:
-            original_add_request = self.vllm_engine.add_request
-
-            async def add_request_with_lora(*args, **kwargs):
-                kwargs['lora_request'] = processed_request.lora_request
-                return await original_add_request(*args, **kwargs)
-
-            self.vllm_engine.add_request = add_request_with_lora
-
-        try:
-            response = await processed_request.inference_invoker(
-                processed_request.vllm_request)
-        finally:
-            if processed_request.lora_request:
-                self.vllm_engine.add_request = original_add_request
+        # vLLM will extract the adapter from the request object via _maybe_get_adapters()
+        response = await processed_request.inference_invoker(
+            processed_request.vllm_request)
 
         if isinstance(response, types.AsyncGeneratorType):
             # Apply custom formatter to streaming response
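
The deleted block is where the race lived: add_request was monkey-patched on the shared engine instance, so two overlapping streaming requests could capture each other's patch or restore a stale method mid-stream. A self-contained sketch of that failure mode under asyncio concurrency (hypothetical FakeEngine, not the DJL source):

import asyncio

class FakeEngine:
    async def add_request(self, prompt, lora_request=None):
        await asyncio.sleep(0)   # yield to the event loop, as a real engine would
        return lora_request      # report which adapter was actually applied

engine = FakeEngine()

async def handle(request_id: int, adapter: str):
    # Old pattern: patch the SHARED engine method, call it, then restore it.
    original = engine.add_request

    async def patched(*args, **kwargs):
        kwargs["lora_request"] = adapter
        return await original(*args, **kwargs)

    engine.add_request = patched
    try:
        applied = await engine.add_request(f"prompt-{request_id}")
    finally:
        engine.add_request = original  # may restore another request's patch
    return request_id, adapter, applied

async def main():
    # Two concurrent requests: the second captures the first request's patch
    # as its "original", so its own adapter is silently overwritten.
    for rid, wanted, got in await asyncio.gather(handle(1, "adapter-A"),
                                                 handle(2, "adapter-B")):
        print(f"request {rid}: wanted {wanted}, got {got}")

asyncio.run(main())
# request 1: wanted adapter-A, got adapter-A
# request 2: wanted adapter-B, got adapter-A   <- the race

Carrying the adapter on the request object itself, as this commit does, makes the adapter per-request state rather than shared engine state, so the patch/restore dance is no longer needed.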
