Commit f228a8a

Author: Suma Kasa
Message: Add fixes for vllm nightly build
Parent: 23e8fe8

File tree (6 files changed: +47 −24 lines)

engines/python/setup/djl_python/async_utils.py
engines/python/setup/djl_python/lmi_vllm/request_response_utils.py
engines/python/setup/djl_python/lmi_vllm/vllm_async_service.py
serving/docker/lmi-container-requirements.txt
serving/docker/lmi.Dockerfile
tests/integration/llm/client.py

engines/python/setup/djl_python/async_utils.py

Lines changed: 1 addition & 1 deletion
@@ -128,7 +128,7 @@ def _extract_lora_adapter(raw_request, decoded_payload):
             SAGEMAKER_ADAPTER_IDENTIFIER_HEADER)
         logging.debug(f"Found adapter in headers: {adapter_name}")
     elif "adapter" in decoded_payload:
-        adapter_name = decoded_payload.pop("adapter")
+        adapter_name = decoded_payload.get("adapter")
         logging.debug(f"Found adapter in payload: {adapter_name}")

     return adapter_name
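The only functional change here is pop() to get(): pop() removes the "adapter" key from the shared payload dict, so any later reader of the same payload no longer sees it, while get() reads the value without mutating anything. A minimal sketch of the difference (illustrative values, not part of the commit):

    payload = {"inputs": "hello", "adapter": "my-lora"}

    # pop() mutates the dict: a later lookup of "adapter" would fail.
    adapter_name = payload.pop("adapter")
    assert "adapter" not in payload

    payload = {"inputs": "hello", "adapter": "my-lora"}

    # get() leaves the payload intact for downstream consumers.
    adapter_name = payload.get("adapter")
    assert "adapter" in payload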

engines/python/setup/djl_python/lmi_vllm/request_response_utils.py

Lines changed: 11 additions & 3 deletions
@@ -11,6 +11,7 @@
 # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
 # the specific language governing permissions and limitations under the License.
 import json
+import logging
 from typing import Callable, Tuple, Union, List, Dict
 from vllm.entrypoints.openai.protocol import (
     CompletionRequest,
@@ -26,6 +27,8 @@
 from djl_python.outputs import Output
 from djl_python.async_utils import create_non_stream_output, create_stream_chunk_output

+logger = logging.getLogger(__name__)
+

 class ProcessedRequest:

@@ -52,16 +55,21 @@ def __init__(

 def convert_lmi_schema_to_completion_request(
         payload: dict, ) -> Tuple[CompletionRequest, bool, bool]:
-    parameters = payload.get("parameters", {})
+    # Create a copy to avoid mutating the original
+    parameters = payload.get("parameters", {}).copy()
+
+    prompt = payload.get("inputs", "")
+    if not prompt:
+        raise ValueError("Input prompt cannot be empty")

     completion_dict = {
-        "prompt": payload.pop("inputs"),
+        "prompt": prompt,
         "max_tokens": parameters.pop("max_new_tokens", 30),
         "echo": parameters.pop("return_full_text", False),
         "truncate_prompt_tokens": parameters.pop("truncate", None),
         "n": parameters.pop("top_n_tokens", 1),
         "ignore_eos": parameters.pop("ignore_eos_token", False),
-        "stream": payload.pop("stream", False),
+        "stream": payload.get("stream", False),
     }
     # 1. when details are requested, return token details for the likely tokens (logprobs=1)
     # TGI only returns prompt token details when details is also enabled
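Net effect of this hunk: the converter validates that "inputs" is non-empty, stops popping keys off the caller's payload, and consumes tuning options from a copy of the nested "parameters" dict. A rough sketch of how a TGI-style payload flows through the mapping (field names taken from the diff; values are illustrative):

    payload = {
        "inputs": "What is Deep Learning?",
        "parameters": {"max_new_tokens": 64, "return_full_text": False},
        "stream": True,
    }

    # pop() on the *copy* consumes mapped keys without touching the caller's dict.
    parameters = payload.get("parameters", {}).copy()
    completion_dict = {
        "prompt": payload.get("inputs", ""),
        "max_tokens": parameters.pop("max_new_tokens", 30),
        "echo": parameters.pop("return_full_text", False),
        "stream": payload.get("stream", False),
    }

    # The original payload and its nested parameters are left untouched.
    assert payload["parameters"] == {"max_new_tokens": 64, "return_full_text": False}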

engines/python/setup/djl_python/lmi_vllm/vllm_async_service.py

Lines changed: 18 additions & 10 deletions
@@ -10,6 +10,8 @@
 # or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
 # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
 # the specific language governing permissions and limitations under the License.
+import asyncio
+import copy
 import logging
 import os
 import types
@@ -23,7 +25,8 @@
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels, BaseModelPath
-from vllm.utils import kill_process_tree, AtomicCounter
+from vllm.utils import AtomicCounter
+from vllm.utils.system_utils import kill_process_tree

 from djl_python.properties_manager.hf_properties import HuggingFaceProperties
 from djl_python.properties_manager.vllm_rb_properties import VllmRbProperties
@@ -74,6 +77,7 @@ def __init__(self):
         self.adapter_registry = {}
         self.lora_id_counter = AtomicCounter(0)
         self.lora_requests = {}
+        self._lora_lock = asyncio.Lock()

     async def initialize(self, properties: dict):
         self.hf_configs = HuggingFaceProperties(**properties)
@@ -93,7 +97,7 @@ async def initialize(self, properties: dict):
         self.vllm_engine = AsyncLLMEngine.from_engine_args(
             self.vllm_engine_args)
         self.tokenizer = await self.vllm_engine.get_tokenizer()
-        model_config = await self.vllm_engine.get_model_config()
+        model_config = self.vllm_engine.model_config

         model_names = self.vllm_engine_args.served_model_name or "lmi"
         if not isinstance(model_names, list):
@@ -108,19 +112,16 @@ async def initialize(self, properties: dict):
         self.model_name = model_names[0]
         self.model_registry = OpenAIServingModels(
             self.vllm_engine,
-            model_config,
             base_model_paths,
         )
         self.completion_service = OpenAIServingCompletion(
             self.vllm_engine,
-            model_config,
             self.model_registry,
             request_logger=None,
         )

         self.chat_completion_service = OpenAIServingChat(
             self.vllm_engine,
-            model_config,
             self.model_registry,
             "assistant",
             request_logger=None,
@@ -142,6 +143,9 @@ def preprocess_request(self, inputs: Input) -> ProcessedRequest:
         session = get_session(self.session_manager, raw_request)
         content_type = raw_request.get_property("Content-Type")
         decoded_payload = decode(raw_request, content_type)
+        # Create a deep copy to prevent mutations from affecting the original
+        decoded_payload = copy.deepcopy(decoded_payload)
+        logger.info(f"Decoded payload after deepcopy: inputs={decoded_payload.get('inputs', 'N/A')}, stream={decoded_payload.get('stream', 'N/A')}")

         adapter_name = _extract_lora_adapter(raw_request, decoded_payload)

@@ -177,8 +181,10 @@ def preprocess_request(self, inputs: Input) -> ProcessedRequest:
             stream_output_formatter = vllm_stream_output_formatter
         # TGI request gets mapped to completions
         elif "inputs" in decoded_payload:
+            logger.info(f"Before convert_lmi_schema: inputs={decoded_payload.get('inputs', 'N/A')}")
             vllm_request, include_details, include_prompt = convert_lmi_schema_to_completion_request(
                 decoded_payload)
+            logger.info(f"After convert_lmi_schema: vllm_request.prompt={vllm_request.prompt if hasattr(vllm_request, 'prompt') else 'N/A'}")
             vllm_invoke_function = self.completion_service.create_completion
             non_stream_output_formatter = lmi_with_details_non_stream_output_formatter if include_details else lmi_non_stream_output_formatter
             stream_output_formatter = lmi_with_details_stream_output_formatter if include_details else lmi_stream_output_formatter
@@ -242,20 +248,22 @@ async def inference(
             return output

         if processed_request.lora_request:
+            logger.info(f"Processing LoRA request: {processed_request.lora_request.lora_name}")
             original_add_request = self.vllm_engine.add_request

             async def add_request_with_lora(*args, **kwargs):
                 kwargs['lora_request'] = processed_request.lora_request
                 return await original_add_request(*args, **kwargs)

             self.vllm_engine.add_request = add_request_with_lora
-
-        try:
+            try:
+                response = await processed_request.inference_invoker(
+                    processed_request.vllm_request)
+            finally:
+                self.vllm_engine.add_request = original_add_request
+        else:
             response = await processed_request.inference_invoker(
                 processed_request.vllm_request)
-        finally:
-            if processed_request.lora_request:
-                self.vllm_engine.add_request = original_add_request

         if isinstance(response, types.AsyncGeneratorType):
             # Apply custom formatter to streaming response
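The restructured inference() keeps the temporary monkey-patch of the engine's add_request inside a try/finally that is scoped to the LoRA branch, so the original method is always restored even when the invocation raises, and non-LoRA requests never touch the patch. A standalone sketch of the same pattern (FakeEngine and run_with_lora are illustrative stand-ins, not the service code):

    import asyncio


    class FakeEngine:

        async def add_request(self, prompt, **kwargs):
            # Stand-in for vllm's AsyncLLMEngine.add_request in this sketch.
            return {"prompt": prompt, **kwargs}


    async def run_with_lora(engine, lora_request, invoke):
        original_add_request = engine.add_request

        async def add_request_with_lora(*args, **kwargs):
            # Inject the per-request LoRA adapter into every add_request call.
            kwargs["lora_request"] = lora_request
            return await original_add_request(*args, **kwargs)

        engine.add_request = add_request_with_lora
        try:
            return await invoke()
        finally:
            # Restore the unpatched method even if invoke() raises.
            engine.add_request = original_add_request


    engine = FakeEngine()
    result = asyncio.run(
        run_with_lora(engine, {"lora_name": "demo-adapter"},
                      lambda: engine.add_request("hello")))
    assert result["lora_request"] == {"lora_name": "demo-adapter"}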

serving/docker/lmi-container-requirements.txt

Lines changed: 6 additions & 5 deletions
@@ -1,8 +1,9 @@
-torch==2.8.0
+torch==2.9.0
+autoawq
 torchvision
 peft==0.15.1
 protobuf==4.25.1
-transformers==4.55.2
+transformers==4.56.0
 hf-transfer
 zstandard
 datasets==3.0.1
@@ -25,12 +26,12 @@ sentence_transformers
 onnxruntime-gpu==1.20.0
 autoawq
 tokenizers
-pydantic==2.11.7
+pydantic>=2.12.0
 optimum==1.23.2
 uvloop
 ninja
 peft
 llmcompressor
-vllm==0.11.0
+vllm @ git+https://github.com/vllm-project/vllm.git
 xgrammar
-flashinfer-python==0.2.5
+flashinfer-python==0.4.1

serving/docker/lmi.Dockerfile

Lines changed: 2 additions & 1 deletion
@@ -89,7 +89,8 @@ RUN scripts/patch_oss_dlc.sh python \
     && apt-get clean -y && rm -rf /var/lib/apt/lists/*

 COPY lmi-container-requirements.txt ./requirements.txt
-RUN pip3 install torch==2.8.0 torchvision \
+RUN pip3 install --upgrade pip setuptools
+RUN pip3 install torch==2.9.0 torchvision \
     && pip3 install -r requirements.txt \
     && pip3 install ${djl_converter_wheel} --no-deps


tests/integration/llm/client.py

Lines changed: 9 additions & 4 deletions
@@ -1800,7 +1800,7 @@ def test_handler_adapters(model, model_spec):
         }
         req["parameters"] = params
         req["adapters"] = adapter
-        reqs.append(req)
+        reqs.append(req.copy())
     for req in reqs:
         for stream in stream_values:
             req["stream"] = stream
@@ -1830,13 +1830,19 @@ def test_handler_adapters(model, model_spec):
     LOGGER.info(f"del adapter {res}")
     headers = {'content-type': 'application/json'}
     endpoint = f"http://127.0.0.1:8080/invocations"
+    # Create a fresh copy to avoid using mutated request
+    import copy
+    req0_copy = copy.deepcopy(reqs[0])
     res = requests.post(endpoint, headers=headers,
-                        json=reqs[0]).content.decode("utf-8")
+                        json=req0_copy).content.decode("utf-8")
     LOGGER.info(f"call deleted adapter {res}")

     if len(reqs) > 1:
+        # Create a fresh copy to avoid using mutated request
+        req1_copy = copy.deepcopy(reqs[1])
+        LOGGER.info(f"Request being sent: {req1_copy}")
         res = requests.post(endpoint, headers=headers,
-                            json=reqs[1]).content.decode("utf-8")
+                            json=req1_copy).content.decode("utf-8")
         LOGGER.info(f"call valid adapter after deletion {res}")
         if not res or res.strip() == "":
             LOGGER.error(f"Empty response received from model API: {res}")
@@ -1872,7 +1878,6 @@ def test_handler_adapters(model, model_spec):
         LOGGER.error(msg)
         raise RuntimeError(msg)

-
 def test_handler_rolling_batch_chat(model, model_spec):
     modelspec_checker(model, model_spec)
     spec = model_spec[args.model]
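The test hunks all address the same aliasing pitfall: appending the same dict object to reqs and then mutating it in a later loop (for example, req["stream"] = stream) means every element points at one shared dict, so earlier requests are silently rewritten. A minimal sketch of the failure mode and the fix (illustrative values such as "adapter-1" are made up):

    import copy

    reqs = []
    req = {"inputs": "hello"}
    for adapter in ["adapter-1", "adapter-2"]:
        req["adapters"] = adapter
        reqs.append(req)              # every entry is the SAME dict object

    assert reqs[0]["adapters"] == "adapter-2"   # first request silently overwritten

    reqs = []
    req = {"inputs": "hello"}
    for adapter in ["adapter-1", "adapter-2"]:
        req["adapters"] = adapter
        reqs.append(copy.deepcopy(req))   # each entry is an independent snapshot

    assert reqs[0]["adapters"] == "adapter-1"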
