Skip to content
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 17 additions & 15 deletions .github/workflows/integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -145,34 +145,36 @@ jobs:
- test: TestAarch64
instance: aarch64
failure-prefix: aarch64
- test: TestHfHandler
# - test: TestHfHandler
# instance: g6
# failure-prefix: lmi
# - test: TestTrtLlmHandler1
# instance: g6
# failure-prefix: trtllm
# - test: TestTrtLlmHandler2
# instance: g6
# failure-prefix: trtllm
- test: TestVllm1
instance: g6
failure-prefix: lmi
- test: TestTrtLlmHandler1
instance: g6
failure-prefix: trtllm
- test: TestTrtLlmHandler2
- test: TestVllmCustomHandlers
instance: g6
failure-prefix: trtllm
- test: TestVllm1
failure-prefix: lmi
- test: TestVllmLora
instance: g6
failure-prefix: lmi
- test: TestVllmCustomHandlers
- test: TestVllmAsyncLora
instance: g6
failure-prefix: lmi
# - test: TestVllmLora
# instance: g6
# failure-prefix: lmi

- test: TestMultiModalVllm
instance: g6
failure-prefix: lmi
# - test: TestTextEmbedding
# instance: g6
# failure-prefix: lmi
- test: TestCorrectnessTrtLlm
instance: g6
failure-prefix: trtllm
# - test: TestCorrectnessTrtLlm
# instance: g6
# failure-prefix: trtllm
- test: TestStatefulModel
instance: g6
failure-prefix: lmi
Expand Down
6 changes: 4 additions & 2 deletions engines/python/setup/djl_python/input_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from djl_python.request_io import TextInput, RequestInput
from djl_python.three_p.three_p_utils import parse_3p_request

SAGEMAKER_ADAPTER_IDENTIFIER_HEADER = "X-Amzn-SageMaker-Adapter-Identifier"


def input_formatter(function):
"""
Expand Down Expand Up @@ -237,9 +239,9 @@ def _fetch_adapters_from_input(input_map: dict, input_item: Input,

# check properties, possible from header
adapter_alias = None
if "X-Amzn-SageMaker-Adapter-Identifier" in input_item.get_properties():
if SAGEMAKER_ADAPTER_IDENTIFIER_HEADER in input_item.get_properties():
adapters_per_item = input_item.get_property(
"X-Amzn-SageMaker-Adapter-Identifier")
SAGEMAKER_ADAPTER_IDENTIFIER_HEADER)
adapter_alias = input_item.get_property(
"X-Amzn-SageMaker-Adapter-Alias")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(
self.stream_output_formatter = stream_output_formatter
self.accumulate_chunks = accumulate_chunks
self.include_prompt = include_prompt
self.lora_request = None


def convert_lmi_schema_to_completion_request(
Expand Down
219 changes: 215 additions & 4 deletions engines/python/setup/djl_python/lmi_vllm/vllm_async_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,18 @@
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_models import OpenAIServingModels, BaseModelPath
from vllm.utils import kill_process_tree
from vllm.utils import kill_process_tree, AtomicCounter

from djl_python.properties_manager.hf_properties import HuggingFaceProperties
from djl_python.properties_manager.vllm_rb_properties import VllmRbProperties
from djl_python.inputs import Input
from djl_python.outputs import Output
from djl_python.encode_decode import decode
from djl_python.async_utils import handle_streaming_response, create_non_stream_output
from djl_python.custom_formatter_handling import CustomFormatterError, CustomFormatterHandler
from djl_python.custom_formatter_handling import CustomFormatterHandler, CustomFormatterError
from djl_python.rolling_batch.rolling_batch_vllm_utils import create_lora_request, get_lora_request
from djl_python.input_parser import SAGEMAKER_ADAPTER_IDENTIFIER_HEADER

from djl_python.lmi_vllm.request_response_utils import (
ProcessedRequest,
vllm_stream_output_formatter,
Expand Down Expand Up @@ -67,6 +70,9 @@ def __init__(self):
self.vllm_properties = None
self.model_name = None
self.initialized = False
self.adapter_registry = {}
self.lora_id_counter = AtomicCounter(0)
self.lora_requests = {}

async def initialize(self, properties: dict):
self.hf_configs = HuggingFaceProperties(**properties)
Expand Down Expand Up @@ -136,6 +142,8 @@ def preprocess_request(self, inputs: Input) -> ProcessedRequest:
content_type = raw_request.get_property("Content-Type")
decoded_payload = decode(raw_request, content_type)

adapter_name = self._extract_lora_adapter(raw_request, decoded_payload)

# Apply input formatter
decoded_payload = self.apply_input_formatter(decoded_payload,
tokenizer=self.tokenizer)
Expand All @@ -148,6 +156,18 @@ def preprocess_request(self, inputs: Input) -> ProcessedRequest:
# completions/chat completions require model in the payload
if "model" not in decoded_payload:
decoded_payload["model"] = self.model_name

lora_request = None
if adapter_name:
if adapter_name not in self.lora_requests:
raise ValueError(
f"LoRA adapter {adapter_name} not found in registry. Available adapters: {list(self.lora_requests.keys())}"
)
lora_request = get_lora_request(adapter_name, self.lora_requests)
logging.info(
f"Using LoRA request: {lora_request.lora_name} (ID: {lora_request.lora_int_id})"
)

# completions request
if "prompt" in decoded_payload:
vllm_request = CompletionRequest(**decoded_payload)
Expand Down Expand Up @@ -193,6 +213,7 @@ def preprocess_request(self, inputs: Input) -> ProcessedRequest:
accumulate_chunks,
include_prompt,
)
processed_request.lora_request = lora_request
return processed_request

async def check_health(self):
Expand All @@ -219,8 +240,21 @@ async def inference(
"", error=f"Input parsing failed: {str(e)}", code=424)
return output

response = await processed_request.inference_invoker(
processed_request.vllm_request)
if processed_request.lora_request:
original_add_request = self.vllm_engine.add_request

async def add_request_with_lora(*args, **kwargs):
kwargs['lora_request'] = processed_request.lora_request
return await original_add_request(*args, **kwargs)

self.vllm_engine.add_request = add_request_with_lora

try:
response = await processed_request.inference_invoker(
processed_request.vllm_request)
finally:
if processed_request.lora_request:
self.vllm_engine.add_request = original_add_request

if isinstance(response, types.AsyncGeneratorType):
# Apply custom formatter to streaming response
Expand All @@ -244,6 +278,46 @@ async def inference(
tokenizer=self.tokenizer,
)

async def add_lora(self, lora_name: str, lora_alias: str, lora_path: str):
    """
    Load a LoRA adapter into the running vLLM engine.

    Allocates a fresh integer id from the atomic counter, records the
    resulting LoRA request in ``self.lora_requests`` keyed by adapter
    name, then forwards it to the engine.

    :param lora_name: registry key for the adapter
    :param lora_alias: user-facing alias (not used by this method)
    :param lora_path: local filesystem path of the adapter weights
    :return: the engine's ``add_lora`` result
    """
    logging.info(f"Adding LoRA {lora_name} from {lora_path}")
    next_id = self.lora_id_counter.inc(1)
    new_request = create_lora_request(lora_name, next_id, lora_path, None)
    self.lora_requests[new_request.lora_name] = new_request
    engine_result = await self.vllm_engine.add_lora(new_request)
    logging.info(f"LoRA {lora_name} added to engine: {engine_result}")
    return engine_result

async def remove_lora(self, lora_name: str, lora_alias: str):
    """
    Unload a LoRA adapter from the engine and drop it from the registry.

    :param lora_name: registry key of the adapter to remove
    :param lora_alias: user-facing alias (not used by this method)
    :raises ValueError: if the adapter was never registered
    :return: the engine's ``remove_lora`` result
    """
    logging.info(f"Removing LoRA {lora_name}")
    if lora_name not in self.lora_requests:
        raise ValueError(f"LoRA adapter {lora_name} not found in registry")
    entry = get_lora_request(lora_name, self.lora_requests)
    engine_result = await self.vllm_engine.remove_lora(entry.lora_int_id)
    self.lora_requests.pop(lora_name)
    return engine_result

async def pin_lora(self, lora_name: str, lora_alias: str):
    """
    Pin a LoRA adapter so the engine keeps it resident.

    The adapter is (re-)added first, since the engine requires a loaded
    adapter before it can be pinned.

    :param lora_name: registry key of the adapter to pin
    :param lora_alias: user-facing alias (not used by this method)
    :raises ValueError: if the adapter was never registered
    :return: truthy only if both the add and the pin succeeded
    """
    # Consistency fix: remove_lora and preprocess_request both validate
    # registry membership with a clear ValueError; previously a missing
    # adapter fell straight into get_lora_request with an unclear
    # failure mode.
    if lora_name not in self.lora_requests:
        raise ValueError(f"LoRA adapter {lora_name} not found in registry")
    lora_request = get_lora_request(lora_name, self.lora_requests)
    loaded = await self.vllm_engine.add_lora(lora_request)
    return loaded and await self.vllm_engine.pin_lora(
        lora_request.lora_int_id)

def _extract_lora_adapter(self, raw_request, decoded_payload):
    """
    Get lora adapter name from request headers or payload.

    The SageMaker adapter header takes precedence; otherwise an
    "adapter" key is popped from the payload so it is not forwarded
    downstream. Returns None when no adapter is specified.
    """
    properties = raw_request.get_properties()
    if SAGEMAKER_ADAPTER_IDENTIFIER_HEADER in properties:
        name = raw_request.get_property(SAGEMAKER_ADAPTER_IDENTIFIER_HEADER)
        logging.debug(f"Found adapter in headers: {name}")
        return name
    if "adapter" in decoded_payload:
        name = decoded_payload.pop("adapter")
        logging.debug(f"Found adapter in payload: {name}")
        return name
    return None


service = VLLMHandler()

Expand All @@ -259,3 +333,140 @@ async def handle(

outputs = await service.inference(inputs)
return outputs


async def register_adapter(inputs: Input):
    """
    Registers lora adapter with the model.

    Reads the adapter descriptor from *inputs* ("name", optional "alias",
    "src" path, optional "preload"/"pin" booleans), loads and optionally
    pins it via the service, and records it in the adapter registry.
    Failures are reported as a 424 output rather than raised, except
    LoRA-slot exhaustion which is surfaced as MemoryError.
    """
    adapter_name = inputs.get_property("name")
    adapter_alias = inputs.get_property("alias") or adapter_name
    adapter_path = inputs.get_property("src")
    # Both flags are string properties; absent means preload=True, pin=False.
    adapter_preload = inputs.get_as_string("preload").lower(
    ) == "true" if inputs.contains_key("preload") else True
    adapter_pin = inputs.get_as_string(
        "pin").lower() == "true" if inputs.contains_key("pin") else False

    outputs = Output()
    loaded = False
    try:
        if not os.path.exists(adapter_path):
            raise ValueError(
                f"Only local LoRA models are supported. {adapter_path} is not a valid path"
            )

        if not adapter_preload and adapter_pin:
            raise ValueError("Can not set preload to false and pin to true")

        if adapter_preload:
            loaded = await service.add_lora(adapter_name, adapter_alias,
                                            adapter_path)

        if adapter_pin:
            await service.pin_lora(adapter_name, adapter_alias)
        service.adapter_registry[adapter_name] = inputs
    except Exception as e:
        logging.debug(f"Failed to register adapter: {e}", exc_info=True)
        if loaded:
            logging.info(
                f"LoRA adapter {adapter_alias} was successfully loaded, but failed to pin, unloading ..."
            )
            # Bug fix: if the rollback itself raised, the original error
            # escaped instead of being reported as a 424 output below.
            # Log the rollback failure and keep handling the original error.
            try:
                await service.remove_lora(adapter_name, adapter_alias)
            except Exception as rollback_error:
                logging.debug(
                    f"Failed to unload adapter during rollback: {rollback_error}",
                    exc_info=True)
        if any(msg in str(e)
               for msg in ("No free lora slots",
                           "greater than the number of GPU LoRA slots")):
            # Out-of-slots is surfaced as MemoryError so the frontend can
            # treat it as a capacity problem rather than a bad request.
            raise MemoryError(str(e))
        err = {"data": "", "last": True, "code": 424, "error": str(e)}
        outputs.add(Output.binary_encode(err), key="data")
        return outputs

    logging.info(
        f"Registered adapter {adapter_alias} from {adapter_path} successfully")
    result = {"data": f"Adapter {adapter_alias} registered"}
    outputs.add(Output.binary_encode(result), key="data")
    return outputs


async def update_adapter(inputs: Input):
    """
    Updates lora adapter with the model.

    Only the "preload" and "pin" properties may change: the adapter path
    must stay the same and unpinning is not supported. Failures are
    reported as a 424 output rather than raised, except LoRA-slot
    exhaustion which is surfaced as MemoryError.
    """
    adapter_name = inputs.get_property("name")
    adapter_alias = inputs.get_property("alias") or adapter_name
    adapter_path = inputs.get_property("src")
    # Both flags are string properties; absent means preload=True, pin=False.
    adapter_preload = inputs.get_as_string("preload").lower(
    ) == "true" if inputs.contains_key("preload") else True
    adapter_pin = inputs.get_as_string(
        "pin").lower() == "true" if inputs.contains_key("pin") else False

    if adapter_name not in service.adapter_registry:
        raise ValueError(f"Adapter {adapter_alias} not registered.")

    outputs = Output()
    try:
        if not adapter_preload and adapter_pin:
            # Bug fix: the message previously said "load" although the
            # property is named "preload"; now consistent with
            # register_adapter's wording.
            raise ValueError("Can not set preload to false and pin to true")

        old_adapter = service.adapter_registry[adapter_name]
        old_adapter_path = old_adapter.get_property("src")
        if adapter_path != old_adapter_path:
            raise NotImplementedError(
                "Updating adapter path is not supported.")

        old_adapter_preload = old_adapter.get_as_string("preload").lower(
        ) == "true" if old_adapter.contains_key("preload") else True
        if adapter_preload != old_adapter_preload:
            if adapter_preload:
                await service.add_lora(adapter_name, adapter_alias,
                                       adapter_path)
            else:
                await service.remove_lora(adapter_name, adapter_alias)

        old_adapter_pin = old_adapter.get_as_string("pin").lower(
        ) == "true" if old_adapter.contains_key("pin") else False
        if adapter_pin != old_adapter_pin:
            if adapter_pin:
                await service.pin_lora(adapter_name, adapter_alias)
            else:
                raise NotImplementedError("Unpin adapter is not supported.")
        service.adapter_registry[adapter_name] = inputs
    except Exception as e:
        logging.debug(f"Failed to update adapter: {e}", exc_info=True)
        if any(msg in str(e)
               for msg in ("No free lora slots",
                           "greater than the number of GPU LoRA slots")):
            # Out-of-slots is surfaced as MemoryError so the frontend can
            # treat it as a capacity problem rather than a bad request.
            raise MemoryError(str(e))
        err = {"data": "", "last": True, "code": 424, "error": str(e)}
        outputs.add(Output.binary_encode(err), key="data")
        return outputs

    logging.info(f"Updated adapter {adapter_alias} successfully")
    result = {"data": f"Adapter {adapter_alias} updated"}
    outputs.add(Output.binary_encode(result), key="data")
    return outputs


async def unregister_adapter(inputs: Input):
    """
    Unregisters lora adapter from the model.

    Removes the adapter from the engine and from the service registry.
    Failures during removal are reported as a 424 output instead of
    being raised.
    """
    adapter_name = inputs.get_property("name")
    adapter_alias = inputs.get_property("alias") or adapter_name

    if adapter_name not in service.adapter_registry:
        raise ValueError(f"Adapter {adapter_alias} not registered.")

    outputs = Output()
    try:
        await service.remove_lora(adapter_name, adapter_alias)
        service.adapter_registry.pop(adapter_name)
    except Exception as e:
        logging.debug(f"Failed to unregister adapter: {e}", exc_info=True)
        error_payload = {
            "data": "",
            "last": True,
            "code": 424,
            "error": str(e)
        }
        outputs.add(Output.binary_encode(error_payload), key="data")
        return outputs

    logging.info(f"Unregistered adapter {adapter_alias} successfully")
    success_payload = {"data": f"Adapter {adapter_alias} unregistered"}
    outputs.add(Output.binary_encode(success_payload), key="data")
    return outputs
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ Output predict(Input inputs, int timeout, boolean initialLoad) throws TranslateE
// In RollingBatch, we queue adapter loading jobs to occur after the initial load.
// Executing those in RollingBatch context doesn't work, so we need to handle them in the
// 'standard' way.
if (initialLoad || inputs.getProperty("handler", null) != null) {
if (initialLoad
|| (inputs.getProperty("handler", null) != null && asyncRequestManager == null)) {
return predictStandard(inputs, timeout, initialLoad);
}
if (rollingBatch != null) {
Expand Down
Loading
Loading