Skip to content

Commit d0f8566

Browse files
committed
[0.34.0-dlc][python] Support adapters in async vLLM handler (#2901)
1 parent ba7830f commit d0f8566

File tree

10 files changed

+393
-84
lines changed

10 files changed

+393
-84
lines changed

.github/workflows/integration.yml

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -145,34 +145,36 @@ jobs:
145145
- test: TestAarch64
146146
instance: aarch64
147147
failure-prefix: aarch64
148-
- test: TestHfHandler
148+
# - test: TestHfHandler
149+
# instance: g6
150+
# failure-prefix: lmi
151+
# - test: TestTrtLlmHandler1
152+
# instance: g6
153+
# failure-prefix: trtllm
154+
# - test: TestTrtLlmHandler2
155+
# instance: g6
156+
# failure-prefix: trtllm
157+
- test: TestVllm1
149158
instance: g6
150159
failure-prefix: lmi
151-
- test: TestTrtLlmHandler1
152-
instance: g6
153-
failure-prefix: trtllm
154-
- test: TestTrtLlmHandler2
160+
- test: TestVllmCustomHandlers
155161
instance: g6
156-
failure-prefix: trtllm
157-
- test: TestVllm1
162+
failure-prefix: lmi
163+
- test: TestVllmLora
158164
instance: g6
159165
failure-prefix: lmi
160-
- test: TestVllmCustomHandlers
166+
- test: TestVllmAsyncLora
161167
instance: g6
162168
failure-prefix: lmi
163-
# - test: TestVllmLora
164-
# instance: g6
165-
# failure-prefix: lmi
166-
167169
- test: TestMultiModalVllm
168170
instance: g6
169171
failure-prefix: lmi
170172
# - test: TestTextEmbedding
171173
# instance: g6
172174
# failure-prefix: lmi
173-
- test: TestCorrectnessTrtLlm
174-
instance: g6
175-
failure-prefix: trtllm
175+
# - test: TestCorrectnessTrtLlm
176+
# instance: g6
177+
# failure-prefix: trtllm
176178
- test: TestStatefulModel
177179
instance: g6
178180
failure-prefix: lmi

engines/python/setup/djl_python/async_utils.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,12 @@
1212
# the specific language governing permissions and limitations under the License.
1313
import json
1414
import logging
15+
import os
1516
from typing import AsyncGenerator, Callable, Optional, Union
1617

18+
from djl_python.inputs import Input
1719
from djl_python.outputs import Output
20+
from djl_python.input_parser import SAGEMAKER_ADAPTER_IDENTIFIER_HEADER
1821

1922

2023
def create_non_stream_output(data: Union[str, dict],
@@ -112,3 +115,159 @@ async def handle_streaming_response(
112115
yield output
113116
if last:
114117
return
118+
119+
120+
def _extract_lora_adapter(raw_request, decoded_payload):
    """
    Get lora adapter name from request headers or payload.

    The SageMaker adapter identifier header is checked first. If it is not
    present, the "adapter" key of the payload is used instead; in that case
    the key is popped, so ``decoded_payload`` is mutated.

    :param raw_request: request object exposing ``get_properties()`` and
        ``get_property(name)``
    :param decoded_payload: decoded request body; only dict payloads are
        inspected for an "adapter" key
    :return: the adapter name, or None when neither source provides one
    """
    adapter_name = None

    if SAGEMAKER_ADAPTER_IDENTIFIER_HEADER in raw_request.get_properties():
        adapter_name = raw_request.get_property(
            SAGEMAKER_ADAPTER_IDENTIFIER_HEADER)
        logging.debug(f"Found adapter in headers: {adapter_name}")
    # Only dict payloads can carry an "adapter" key. Without this guard a
    # None payload fails on the membership test and a list payload fails on
    # pop("adapter"), both with a TypeError.
    elif isinstance(decoded_payload, dict) and "adapter" in decoded_payload:
        adapter_name = decoded_payload.pop("adapter")
        logging.debug(f"Found adapter in payload: {adapter_name}")

    return adapter_name
135+
136+
137+
async def register_adapter(inputs: Input, service):
    """
    Registers lora adapter with the model.

    Reads the "name", "alias" (falls back to the name) and "src" properties
    from ``inputs``, plus the optional "preload" (default true) and "pin"
    (default false) flags. When preload is set the adapter is added through
    ``service.add_lora``; when pin is set it is then pinned through
    ``service.pin_lora``. On success ``inputs`` is stored in
    ``service.adapter_registry`` under the adapter name.

    :param inputs: request carrying the adapter properties and flags
    :param service: object providing ``add_lora``, ``pin_lora``,
        ``remove_lora`` coroutines and an ``adapter_registry`` mapping
    :return: Output whose "data" entry holds either a success message or an
        error payload with code 424
    :raises MemoryError: when the failure message indicates that no LoRA
        slot is available
    """
    adapter_name = inputs.get_property("name")
    adapter_alias = inputs.get_property("alias") or adapter_name
    adapter_path = inputs.get_property("src")
    adapter_preload = inputs.get_as_string("preload").lower(
    ) == "true" if inputs.contains_key("preload") else True
    adapter_pin = inputs.get_as_string(
        "pin").lower() == "true" if inputs.contains_key("pin") else False

    outputs = Output()
    loaded = False
    try:
        # Check for a missing/empty "src" first: os.path.exists(None) raises
        # a TypeError whose message does not tell the caller what is wrong.
        if not adapter_path or not os.path.exists(adapter_path):
            raise ValueError(
                f"Only local LoRA models are supported. {adapter_path} is not a valid path"
            )

        if not adapter_preload and adapter_pin:
            raise ValueError("Can not set preload to false and pin to true")

        if adapter_preload:
            loaded = await service.add_lora(adapter_name, adapter_alias,
                                            adapter_path)

        if adapter_pin:
            await service.pin_lora(adapter_name, adapter_alias)
        service.adapter_registry[adapter_name] = inputs
    except Exception as e:
        logging.debug(f"Failed to register adapter: {e}", exc_info=True)
        if loaded:
            # The adapter was added but a later step failed; roll back so a
            # half-registered adapter does not stay loaded.
            logging.info(
                f"LoRA adapter {adapter_alias} was successfully loaded, but failed to pin, unloading ..."
            )
            await service.remove_lora(adapter_name, adapter_alias)
        if any(msg in str(e)
               for msg in ("No free lora slots",
                           "greater than the number of GPU LoRA slots")):
            # Slot exhaustion is surfaced as MemoryError; keep the original
            # exception attached as the cause for debugging.
            raise MemoryError(str(e)) from e
        err = {"data": "", "last": True, "code": 424, "error": str(e)}
        outputs.add(Output.binary_encode(err), key="data")
        return outputs

    logging.info(
        f"Registered adapter {adapter_alias} from {adapter_path} successfully")
    result = {"data": f"Adapter {adapter_alias} registered"}
    outputs.add(Output.binary_encode(result), key="data")
    return outputs
187+
188+
189+
async def update_adapter(inputs: Input, service):
    """
    Updates lora adapter with the model.

    Compares the "preload" and "pin" flags in ``inputs`` with those of the
    entry stored in ``service.adapter_registry`` and applies the difference:
    a preload change adds or removes the adapter, and a pin change from
    false to true pins it. Changing "src" or unpinning is rejected. On
    success the registry entry is replaced with ``inputs``.

    :param inputs: request carrying the adapter properties and flags
    :param service: object providing ``add_lora``, ``pin_lora``,
        ``remove_lora`` coroutines and an ``adapter_registry`` mapping
    :return: Output whose "data" entry holds either a success message or an
        error payload with code 424
    :raises ValueError: when the adapter name is not in the registry
    :raises MemoryError: when the failure message indicates that no LoRA
        slot is available
    """
    adapter_name = inputs.get_property("name")
    adapter_alias = inputs.get_property("alias") or adapter_name
    adapter_path = inputs.get_property("src")
    adapter_preload = inputs.get_as_string("preload").lower(
    ) == "true" if inputs.contains_key("preload") else True
    adapter_pin = inputs.get_as_string(
        "pin").lower() == "true" if inputs.contains_key("pin") else False

    if adapter_name not in service.adapter_registry:
        raise ValueError(f"Adapter {adapter_alias} not registered.")

    outputs = Output()
    try:
        # The flag is named "preload"; use the same wording as
        # register_adapter so both endpoints report the same error.
        if not adapter_preload and adapter_pin:
            raise ValueError("Can not set preload to false and pin to true")

        old_adapter = service.adapter_registry[adapter_name]
        old_adapter_path = old_adapter.get_property("src")
        if adapter_path != old_adapter_path:
            raise NotImplementedError(
                "Updating adapter path is not supported.")

        old_adapter_preload = old_adapter.get_as_string("preload").lower(
        ) == "true" if old_adapter.contains_key("preload") else True
        if adapter_preload != old_adapter_preload:
            if adapter_preload:
                await service.add_lora(adapter_name, adapter_alias,
                                       adapter_path)
            else:
                await service.remove_lora(adapter_name, adapter_alias)

        old_adapter_pin = old_adapter.get_as_string("pin").lower(
        ) == "true" if old_adapter.contains_key("pin") else False
        if adapter_pin != old_adapter_pin:
            if adapter_pin:
                await service.pin_lora(adapter_name, adapter_alias)
            else:
                raise ValueError(
                    f"Unpinning adapter is not supported. To unpin adapter '{adapter_alias}', please delete the adapter and re-register it without pinning."
                )
        service.adapter_registry[adapter_name] = inputs
    except Exception as e:
        logging.debug(f"Failed to update adapter: {e}", exc_info=True)
        if any(msg in str(e)
               for msg in ("No free lora slots",
                           "greater than the number of GPU LoRA slots")):
            # Slot exhaustion is surfaced as MemoryError; keep the original
            # exception attached as the cause for debugging.
            raise MemoryError(str(e)) from e
        err = {"data": "", "last": True, "code": 424, "error": str(e)}
        outputs.add(Output.binary_encode(err), key="data")
        return outputs

    logging.info(f"Updated adapter {adapter_alias} successfully")
    result = {"data": f"Adapter {adapter_alias} updated"}
    outputs.add(Output.binary_encode(result), key="data")
    return outputs
248+
249+
250+
async def unregister_adapter(inputs: Input, service):
    """
    Unregisters lora adapter from the model.

    Removes the adapter through ``service.remove_lora`` and deletes its
    entry from ``service.adapter_registry``.

    :param inputs: request carrying the "name" and optional "alias"
        properties (the alias falls back to the name)
    :param service: object providing a ``remove_lora`` coroutine and an
        ``adapter_registry`` mapping
    :return: Output whose "data" entry holds either a success message or an
        error payload with code 424
    :raises ValueError: when the adapter name is not in the registry
    """
    name = inputs.get_property("name")
    alias = inputs.get_property("alias") or name

    if name not in service.adapter_registry:
        raise ValueError(f"Adapter {alias} not registered.")

    response = Output()
    try:
        await service.remove_lora(name, alias)
        del service.adapter_registry[name]
    except Exception as e:
        # Removal failures are reported to the caller in the response
        # payload rather than raised.
        logging.debug(f"Failed to unregister adapter: {e}", exc_info=True)
        failure = {"data": "", "last": True, "code": 424, "error": str(e)}
        response.add(Output.binary_encode(failure), key="data")
    else:
        logging.info(f"Unregistered adapter {alias} successfully")
        success = {"data": f"Adapter {alias} unregistered"}
        response.add(Output.binary_encode(success), key="data")
    return response

engines/python/setup/djl_python/huggingface.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -604,7 +604,9 @@ def update_adapter(inputs: Input):
604604
if adapter_pin:
605605
_service.pin_lora(adapter_name, adapter_alias)
606606
else:
607-
raise NotImplementedError(f"Unpin adapter is not supported.")
607+
raise ValueError(
608+
f"Unpinning adapter is not supported. To unpin adapter '{adapter_alias}', please delete the adapter and re-register it without pinning."
609+
)
608610
_service.adapter_registry[adapter_name] = inputs
609611
except Exception as e:
610612
logging.debug(f"Failed to update adapter: {e}", exc_info=True)

engines/python/setup/djl_python/input_parser.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
from djl_python.request_io import TextInput, RequestInput
2323
from djl_python.three_p.three_p_utils import parse_3p_request
2424

25+
SAGEMAKER_ADAPTER_IDENTIFIER_HEADER = "X-Amzn-SageMaker-Adapter-Identifier"
26+
2527

2628
def input_formatter(function):
2729
"""
@@ -237,9 +239,9 @@ def _fetch_adapters_from_input(input_map: dict, input_item: Input,
237239

238240
# check properties, possible from header
239241
adapter_alias = None
240-
if "X-Amzn-SageMaker-Adapter-Identifier" in input_item.get_properties():
242+
if SAGEMAKER_ADAPTER_IDENTIFIER_HEADER in input_item.get_properties():
241243
adapters_per_item = input_item.get_property(
242-
"X-Amzn-SageMaker-Adapter-Identifier")
244+
SAGEMAKER_ADAPTER_IDENTIFIER_HEADER)
243245
adapter_alias = input_item.get_property(
244246
"X-Amzn-SageMaker-Adapter-Alias")
245247

engines/python/setup/djl_python/lmi_vllm/request_response_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def __init__(
4747
self.stream_output_formatter = stream_output_formatter
4848
self.accumulate_chunks = accumulate_chunks
4949
self.include_prompt = include_prompt
50+
self.lora_request = None
5051

5152

5253
def convert_lmi_schema_to_completion_request(

0 commit comments

Comments
 (0)